fabf2a2d18df95413e1257a19d65eeea41881fdd
5 from fractions
import Fraction
, gcd
10 'constant', 'symbol', 'symbols',
11 'eq', 'le', 'lt', 'ge', 'gt',
17 def _polymorphic_method(func
):
18 @functools.wraps(func
)
20 if isinstance(b
, Expression
):
22 if isinstance(b
, numbers
.Rational
):
28 def _polymorphic_operator(func
):
29 @functools.wraps(func
)
31 if isinstance(a
, numbers
.Rational
):
33 if isinstance(b
, numbers
.Rational
):
35 if isinstance(a
, Expression
) and isinstance(b
, Expression
):
37 raise TypeError('arguments must be linear expressions')
43 This class implements linear expressions.
46 def __new__(cls
, coefficients
=None, constant
=0):
47 if isinstance(coefficients
, str):
49 raise TypeError('too many arguments')
50 return cls
.fromstring(coefficients
)
51 self
= super().__new
__(cls
)
52 self
._coefficients
= {}
53 if isinstance(coefficients
, dict):
54 coefficients
= coefficients
.items()
55 if coefficients
is not None:
56 for symbol
, coefficient
in coefficients
:
57 if isinstance(symbol
, Expression
) and symbol
.issymbol():
59 elif not isinstance(symbol
, str):
60 raise TypeError('symbols must be strings')
61 if not isinstance(coefficient
, numbers
.Rational
):
62 raise TypeError('coefficients must be rational numbers')
64 self
._coefficients
[symbol
] = coefficient
65 if not isinstance(constant
, numbers
.Rational
):
66 raise TypeError('constant must be a rational number')
67 self
._constant
= constant
71 yield from sorted(self
._coefficients
)
75 return len(list(self
.symbols()))
77 def coefficient(self
, symbol
):
78 if isinstance(symbol
, Expression
) and symbol
.issymbol():
80 elif not isinstance(symbol
, str):
81 raise TypeError('symbol must be a string')
83 return self
._coefficients
[symbol
]
87 __getitem__
= coefficient
89 def coefficients(self
):
90 for symbol
in self
.symbols():
91 yield symbol
, self
.coefficient(symbol
)
98 return len(self
._coefficients
) == 0
101 for symbol
in self
.symbols():
102 yield self
.coefficient(symbol
)
106 if not self
.issymbol():
107 raise ValueError('not a symbol: {}'.format(self
))
108 for symbol
in self
.symbols():
112 return len(self
._coefficients
) == 1 and self
._constant
== 0
115 return (not self
.isconstant()) or bool(self
.constant
)
124 def __add__(self
, other
):
125 coefficients
= dict(self
.coefficients())
126 for symbol
, coefficient
in other
.coefficients():
127 if symbol
in coefficients
:
128 coefficients
[symbol
] += coefficient
130 coefficients
[symbol
] = coefficient
131 constant
= self
.constant
+ other
.constant
132 return Expression(coefficients
, constant
)
137 def __sub__(self
, other
):
138 coefficients
= dict(self
.coefficients())
139 for symbol
, coefficient
in other
.coefficients():
140 if symbol
in coefficients
:
141 coefficients
[symbol
] -= coefficient
143 coefficients
[symbol
] = -coefficient
144 constant
= self
.constant
- other
.constant
145 return Expression(coefficients
, constant
)
150 def __mul__(self
, other
):
151 if other
.isconstant():
152 coefficients
= dict(self
.coefficients())
153 for symbol
in coefficients
:
154 coefficients
[symbol
] *= other
.constant
155 constant
= self
.constant
* other
.constant
156 return Expression(coefficients
, constant
)
157 if isinstance(other
, Expression
) and not self
.isconstant():
158 raise ValueError('non-linear expression: '
159 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
160 return NotImplemented
165 def __truediv__(self
, other
):
166 if other
.isconstant():
167 coefficients
= dict(self
.coefficients())
168 for symbol
in coefficients
:
169 coefficients
[symbol
] = \
170 Fraction(coefficients
[symbol
], other
.constant
)
171 constant
= Fraction(self
.constant
, other
.constant
)
172 return Expression(coefficients
, constant
)
173 if isinstance(other
, Expression
):
174 raise ValueError('non-linear expression: '
175 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
176 return NotImplemented
178 def __rtruediv__(self
, other
):
179 if isinstance(other
, Rational
):
180 if self
.isconstant():
181 constant
= Fraction(other
, self
.constant
)
182 return Expression(constant
=constant
)
184 raise ValueError('non-linear expression: '
185 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
186 return NotImplemented
190 symbols
= sorted(self
.symbols())
192 for symbol
in symbols
:
193 coefficient
= self
[symbol
]
198 string
+= ' + {}'.format(symbol
)
199 elif coefficient
== -1:
201 string
+= '-{}'.format(symbol
)
203 string
+= ' - {}'.format(symbol
)
206 string
+= '{}*{}'.format(coefficient
, symbol
)
207 elif coefficient
> 0:
208 string
+= ' + {}*{}'.format(coefficient
, symbol
)
210 assert coefficient
< 0
212 string
+= ' - {}*{}'.format(coefficient
, symbol
)
214 constant
= self
.constant
215 if constant
!= 0 and i
== 0:
216 string
+= '{}'.format(constant
)
218 string
+= ' + {}'.format(constant
)
221 string
+= ' - {}'.format(constant
)
226 def _parenstr(self
, always
=False):
228 if not always
and (self
.isconstant() or self
.issymbol()):
231 return '({})'.format(string
)
234 string
= '{}({{'.format(self
.__class
__.__name
__)
235 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
238 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
239 string
+= '}}, {!r})'.format(self
.constant
)
243 def fromstring(cls
, string
):
244 raise NotImplementedError
247 def __eq__(self
, other
):
249 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
250 return isinstance(other
, Expression
) and \
251 self
._coefficients
== other
._coefficients
and \
252 self
.constant
== other
.constant
255 return hash((self
._coefficients
, self
._constant
))
258 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
259 [value
.denominator
for value
in self
.values()])
263 def _eq(self
, other
):
264 return Polyhedron(equalities
=[(self
- other
)._canonify
()])
267 def __le__(self
, other
):
268 return Polyhedron(inequalities
=[(self
- other
)._canonify
()])
271 def __lt__(self
, other
):
272 return Polyhedron(inequalities
=[(self
- other
)._canonify
() + 1])
275 def __ge__(self
, other
):
276 return Polyhedron(inequalities
=[(other
- self
)._canonify
()])
279 def __gt__(self
, other
):
280 return Polyhedron(inequalities
=[(other
- self
)._canonify
() + 1])
283 def constant(numerator
=0, denominator
=None):
284 if denominator
is None and isinstance(numerator
, numbers
.Rational
):
285 return Expression(constant
=numerator
)
287 return Expression(constant
=Fraction(numerator
, denominator
))
290 if not isinstance(name
, str):
291 raise TypeError('name must be a string')
292 return Expression(coefficients
={name
: 1})
295 if isinstance(names
, str):
296 names
= names
.replace(',', ' ').split()
297 return (symbol(name
) for name
in names
)
300 @_polymorphic_operator
304 @_polymorphic_operator
308 @_polymorphic_operator
312 @_polymorphic_operator
316 @_polymorphic_operator
323 This class implements polyhedrons.
326 def __new__(cls
, equalities
=None, inequalities
=None):
327 if isinstance(equalities
, str):
328 if inequalities
is not None:
329 raise TypeError('too many arguments')
330 return cls
.fromstring(equalities
)
331 self
= super().__new
__(cls
)
332 self
._equalities
= []
333 if equalities
is not None:
334 for constraint
in equalities
:
335 for value
in constraint
.values():
336 if value
.denominator
!= 1:
337 raise TypeError('non-integer constraint: '
338 '{} == 0'.format(constraint
))
339 self
._equalities
.append(constraint
)
340 self
._inequalities
= []
341 if inequalities
is not None:
342 for constraint
in inequalities
:
343 for value
in constraint
.values():
344 if value
.denominator
!= 1:
345 raise TypeError('non-integer constraint: '
346 '{} <= 0'.format(constraint
))
347 self
._inequalities
.append(constraint
)
351 def equalities(self
):
352 yield from self
._equalities
355 def inequalities(self
):
356 yield from self
._inequalities
358 def constraints(self
):
359 yield from self
.equalities
360 yield from self
.inequalities
364 for constraint
in self
.constraints():
365 s
.update(constraint
.symbols
)
370 return len(self
.symbols())
373 # return false if the polyhedron is empty, true otherwise
374 raise NotImplementedError
376 def __contains__(self
, value
):
377 # is the value in the polyhedron?
378 raise NotImplementedError
380 def __eq__(self
, other
):
381 raise NotImplementedError
386 def isuniverse(self
):
387 return self
== universe
389 def isdisjoint(self
, other
):
390 # return true if the polyhedron has no elements in common with other
391 raise NotImplementedError
393 def issubset(self
, other
):
394 raise NotImplementedError
396 def __le__(self
, other
):
397 return self
.issubset(other
)
399 def __lt__(self
, other
):
400 raise NotImplementedError
402 def issuperset(self
, other
):
403 # test whether every element in other is in the polyhedron
404 raise NotImplementedError
406 def __ge__(self
, other
):
407 return self
.issuperset(other
)
409 def __gt__(self
, other
):
410 raise NotImplementedError
412 def union(self
, *others
):
413 # return a new polyhedron with elements from the polyhedron and all
414 # others (convex union)
415 raise NotImplementedError
417 def __or__(self
, other
):
418 return self
.union(other
)
420 def intersection(self
, *others
):
421 # return a new polyhedron with elements common to the polyhedron and all
423 # a poor man's implementation could be:
424 # equalities = list(self.equalities)
425 # inequalities = list(self.inequalities)
426 # for other in others:
427 # equalities.extend(other.equalities)
428 # inequalities.extend(other.inequalities)
429 # return self.__class__(equalities, inequalities)
430 raise NotImplementedError
432 def __and__(self
, other
):
433 return self
.intersection(other
)
435 def difference(self
, *others
):
436 # return a new polyhedron with elements in the polyhedron that are not
438 raise NotImplementedError
440 def __sub__(self
, other
):
441 return self
.difference(other
)
445 for constraint
in self
.equalities
:
446 constraints
.append('{} == 0'.format(constraint
))
447 for constraint
in self
.inequalities
:
448 constraints
.append('{} <= 0'.format(constraint
))
449 return '{{{}}}'.format(', '.join(constraints
))
452 equalities
= list(self
.equalities
)
453 inequalities
= list(self
.inequalities
)
454 return '{}(equalities={!r}, inequalities={!r})' \
455 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
458 def fromstring(cls
, string
):
459 raise NotImplementedError
464 universe
= Polyhedron()