5b5d8aa50ce4e1f6887b8cd7dd34a53b5bfd54f1
[linpy.git] / pypol / linear.py
1
2 import functools
3 import numbers
4
5 from fractions import Fraction, gcd
6
7
# Public API of this module: expression construction, comparison helpers
# that build polyhedra, and the two distinguished polyhedra.
__all__ = [
    'Expression',
    'constant',
    'symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
15
16
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to non-zero rational
    coefficients and carries a rational constant term.  Arithmetic
    operators never modify their operands; they return Expression objects.

    ``==`` is structural equality (same coefficients and same constant).
    The ordering operators ``<``, ``<=``, ``>`` and ``>=`` instead build
    Polyhedron constraints; use ``_eq`` (or the module-level ``eq``) to
    build an equality constraint.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from coefficients and a constant term.

        coefficients may be None, a dict, or an iterable of
        (symbol, coefficient) pairs.  A symbol is a string or a
        single-symbol Expression; coefficients and constant must be
        numbers.Rational instances.  Zero coefficients are dropped.  A
        string argument is delegated to fromstring().

        Raises TypeError on invalid symbols, coefficients or constant, or
        when a constant is given together with a string.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = str(symbol)
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                # Zero coefficients are not stored, so "symbol present"
                # always means "non-zero coefficient".
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the names of the symbols with a non-zero coefficient, sorted."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not appear.

        symbol is a string or a single-symbol Expression; anything else
        raises TypeError.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs sorted by symbol name."""
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if no symbol has a non-zero coefficient."""
        return not self._coefficients

    def values(self):
        """Yield the coefficients in sorted-symbol order, then the constant."""
        for symbol in self.symbols():
            yield self._coefficients[symbol]
        yield self._constant

    def symbol(self):
        """Return the name of a single-symbol expression; raise ValueError otherwise."""
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """
        Return True if the expression is exactly one symbol: a single
        coefficient equal to 1 and a zero constant.
        """
        # The coefficient value must be checked too: otherwise 2*x or -x
        # would count as symbols and str() would turn them into bogus
        # symbol names such as '2*x'.
        return self._constant == 0 and list(self._coefficients.values()) == [1]

    def __bool__(self):
        # Only the zero expression is false.
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    def _polymorphic(func):
        # Decorator (applied in the class body) for binary operators: a
        # rational operand is promoted to a constant Expression, and any
        # other non-Expression operand yields NotImplemented so Python can
        # try the other operand's reflected method.
        @functools.wraps(func)
        def wrapper(self, other):
            if isinstance(other, numbers.Rational):
                other = Expression(constant=other)
            if isinstance(other, Expression):
                return func(self, other)
            return NotImplemented
        return wrapper

    @_polymorphic
    def __add__(self, other):
        """Return self + other."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    # Addition commutes, so the reflected form can reuse __add__.
    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """Return self - other."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        """Return other - self (e.g. for ``5 - x``)."""
        # Subtraction does not commute, so this cannot alias __sub__:
        # that would compute self - other instead.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        """
        Return self * other.  One operand must be constant; otherwise the
        product is non-linear and ValueError is raised.
        """
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                            for symbol, coefficient in self.coefficients()}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if self.isconstant():
            # constant * expression: scale the other operand instead.
            # Returning NotImplemented would not help here, because Python
            # never tries the reflected method when both operands have the
            # same type.
            return other * self
        raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))

    # Multiplication commutes, so the reflected form can reuse __mul__.
    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Return self / other as an expression with Fraction values.

        Raises ValueError if other is not constant, and ZeroDivisionError
        (from Fraction) if it is the zero constant.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        coefficients = {symbol: Fraction(coefficient, other.constant)
                        for symbol, coefficient in self.coefficients()}
        constant = Fraction(self.constant, other.constant)
        return Expression(coefficients, constant)

    def __rtruediv__(self, other):
        """
        Return other / self for a rational other.

        self must be constant (ValueError otherwise); a zero constant
        raises ZeroDivisionError from Fraction.
        """
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        if not self.isconstant():
            # other is a plain number here, so format it directly.
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return Expression(constant=Fraction(other, self.constant))

    def __str__(self):
        """
        Return a readable form such as ``'2*x - y + 3'``.

        Terms are listed in sorted-symbol order, followed by the constant.
        A constant expression (including zero) is rendered as its value,
        so the zero expression prints as ``'0'``.
        """
        string = ''
        i = 0
        for symbol in self.symbols():
            coefficient = self._coefficients[symbol]
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += '-{}'.format(symbol) if i == 0 \
                    else ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
            i += 1
        constant = self.constant
        if i == 0:
            # No symbolic terms: print the constant even when it is zero,
            # so the result is never an empty string.
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """
        Return str(self), parenthesized unless it is a constant or a single
        symbol (or always when always=True).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                          for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                                         self.constant)

    @classmethod
    def fromstring(cls, string):
        """Parse an expression from a string (not implemented yet)."""
        raise NotImplementedError

    @_polymorphic
    def __eq__(self, other):
        # "normal" (structural) equality, not an equality constraint
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        # The coefficient dict itself is unhashable, so hash its items in
        # sorted order.  Equal ints and Fractions hash alike, which keeps
        # this consistent with __eq__.
        return hash((tuple(self.coefficients()), self._constant))

    def _canonify(self):
        """
        Return self scaled by the least common multiple of the denominators
        of its coefficients and constant, so every value is an integer.
        """
        # math.gcd is used because fractions.gcd was removed in Python 3.9.
        import math
        denominators = [value.denominator for value in self.values()]
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
                               denominators)
        return self * lcm

    @_polymorphic
    def _eq(self, other):
        """Return the Polyhedron for the equality constraint self == other."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic
    def __le__(self, other):
        """Return the Polyhedron for self <= other, stored as self - other <= 0."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic
    def __lt__(self, other):
        """
        Return the Polyhedron for self < other, stored as
        self - other + 1 <= 0.
        """
        # Exact only when the symbols take integer values.  That is
        # presumably the intended domain, since Polyhedron rejects
        # non-integer constraints.
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic
    def __ge__(self, other):
        """Return the Polyhedron for self >= other, stored as other - self <= 0."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic
    def __gt__(self, other):
        """
        Return the Polyhedron for self > other, stored as
        other - self + 1 <= 0 (see __lt__ for the integer assumption).
        """
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
266
267
def constant(numerator=0, denominator=None):
    """Return a constant Expression whose value is Fraction(numerator, denominator)."""
    value = Fraction(numerator, denominator)
    return Expression(constant=value)
270
def symbol(name):
    """
    Return the Expression made of the single symbol name (coefficient 1).

    Raises TypeError if name is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
275
def symbols(names):
    """
    Return a generator of single-symbol Expressions.

    names is either an iterable of strings or one string in which names
    are separated by whitespace and/or commas.
    """
    if isinstance(names, str):
        # Treat every comma as a separator, just like whitespace.
        names = ' '.join(names.split(',')).split()
    return (symbol(n) for n in names)
280
281
def _operator(func):
    """
    Decorate a two-argument comparison so that rational arguments are first
    converted to constant Expressions.

    The wrapped function raises TypeError unless both arguments end up
    being Expressions.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        a = constant(a) if isinstance(a, numbers.Rational) else a
        b = constant(b) if isinstance(b, numbers.Rational) else b
        if not (isinstance(a, Expression) and isinstance(b, Expression)):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
293
@_operator
def eq(a, b):
    """Return the Polyhedron defined by the equality constraint a == b."""
    polyhedron = a._eq(b)
    return polyhedron
297
@_operator
def le(a, b):
    """Return the Polyhedron defined by the constraint a <= b."""
    polyhedron = a <= b
    return polyhedron
301
@_operator
def lt(a, b):
    """Return the Polyhedron defined by the constraint a < b."""
    polyhedron = a < b
    return polyhedron
305
@_operator
def ge(a, b):
    """Return the Polyhedron defined by the constraint a >= b."""
    polyhedron = a >= b
    return polyhedron
309
@_operator
def gt(a, b):
    """Return the Polyhedron defined by the constraint a > b."""
    polyhedron = a > b
    return polyhedron
313
314
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by equalities, each read as
    ``expression == 0``, and inequalities, each read as
    ``expression <= 0``.  Every coefficient and constant of every
    constraint must be an integer.  Most set operations below are still
    stubs that raise NotImplementedError.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of constraint expressions.

        If equalities is a string it is handed to fromstring() (not
        implemented yet).  Raises TypeError when a string is combined with
        inequalities, or when any constraint has a non-integer value.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = [] if equalities is None else \
            cls._integer_constraints(equalities, '==')
        self._inequalities = [] if inequalities is None else \
            cls._integer_constraints(inequalities, '<=')
        return self

    @staticmethod
    def _integer_constraints(constraints, relation):
        """
        Return constraints as a list after checking that every value
        (coefficients and constant) has denominator 1.

        relation ('==' or '<=') is only used in the TypeError message.
        """
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                                    '{} {} 0'.format(constraint, relation))
            checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """Iterator over the equality constraints (each means expr == 0)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterator over the inequality constraints (each means expr <= 0)."""
        yield from self._inequalities

    def constraints(self):
        """Yield all equalities, then all inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield, sorted and without duplicates, every symbol used by a constraint."""
        names = set()
        for constraint in self.constraints():
            # symbols is a method of the constraint and must be called: a
            # bound method object is not iterable.
            names.update(constraint.symbols())
        yield from sorted(names)

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        # symbols() is a generator, which has no len(); materialize it first.
        return len(list(self.symbols()))

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        raise NotImplementedError

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isempty(self):
        # Compares with the module-level `empty`; raises NotImplementedError
        # until __eq__ is implemented.
        return self == empty

    def isuniverse(self):
        # Compares with the module-level `universe`; raises
        # NotImplementedError until __eq__ is implemented.
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and all
        # others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        """Return the constraints as ``{e1 == 0, ..., i1 <= 0, ...}``."""
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} <= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        """Parse a polyhedron from a string (not implemented yet)."""
        raise NotImplementedError
454
455
# The polyhedron with no constraints at all.
universe = Polyhedron()

# The polyhedron given by the single unsatisfiable constraint 1 <= 0.
empty = le(1, 0)