a5f55fa68fda5865d579b88329f15c995f8c082a
5 from fractions
import Fraction
, gcd
10 'constant', 'symbol', 'symbols',
11 'eq', 'le', 'lt', 'ge', 'gt',
17 def _polymorphic_method(func
):
18 @functools.wraps(func
)
20 if isinstance(b
, Expression
):
22 if isinstance(b
, numbers
.Rational
):
28 def _polymorphic_operator(func
):
29 # A polymorphic operator should call a polymorphic method, hence we just
30 # have to test the left operand.
31 @functools.wraps(func
)
33 if isinstance(a
, numbers
.Rational
):
36 elif isinstance(a
, Expression
):
38 raise TypeError('arguments must be linear expressions')
44 This class implements linear expressions.
47 def __new__(cls
, coefficients
=None, constant
=0):
48 if isinstance(coefficients
, str):
50 raise TypeError('too many arguments')
51 return cls
.fromstring(coefficients
)
52 self
= super().__new
__(cls
)
53 self
._coefficients
= {}
54 if isinstance(coefficients
, dict):
55 coefficients
= coefficients
.items()
56 if coefficients
is not None:
57 for symbol
, coefficient
in coefficients
:
58 if isinstance(symbol
, Expression
) and symbol
.issymbol():
60 elif not isinstance(symbol
, str):
61 raise TypeError('symbols must be strings')
62 if not isinstance(coefficient
, numbers
.Rational
):
63 raise TypeError('coefficients must be rational numbers')
65 self
._coefficients
[symbol
] = coefficient
66 if not isinstance(constant
, numbers
.Rational
):
67 raise TypeError('constant must be a rational number')
68 self
._constant
= constant
72 yield from sorted(self
._coefficients
)
76 return len(list(self
.symbols()))
78 def coefficient(self
, symbol
):
79 if isinstance(symbol
, Expression
) and symbol
.issymbol():
81 elif not isinstance(symbol
, str):
82 raise TypeError('symbol must be a string')
84 return self
._coefficients
[symbol
]
88 __getitem__
= coefficient
90 def coefficients(self
):
91 for symbol
in self
.symbols():
92 yield symbol
, self
.coefficient(symbol
)
99 return len(self
._coefficients
) == 0
102 for symbol
in self
.symbols():
103 yield self
.coefficient(symbol
)
107 if not self
.issymbol():
108 raise ValueError('not a symbol: {}'.format(self
))
109 for symbol
in self
.symbols():
113 return len(self
._coefficients
) == 1 and self
._constant
== 0
116 return (not self
.isconstant()) or bool(self
.constant
)
125 def __add__(self
, other
):
126 coefficients
= dict(self
.coefficients())
127 for symbol
, coefficient
in other
.coefficients():
128 if symbol
in coefficients
:
129 coefficients
[symbol
] += coefficient
131 coefficients
[symbol
] = coefficient
132 constant
= self
.constant
+ other
.constant
133 return Expression(coefficients
, constant
)
138 def __sub__(self
, other
):
139 coefficients
= dict(self
.coefficients())
140 for symbol
, coefficient
in other
.coefficients():
141 if symbol
in coefficients
:
142 coefficients
[symbol
] -= coefficient
144 coefficients
[symbol
] = -coefficient
145 constant
= self
.constant
- other
.constant
146 return Expression(coefficients
, constant
)
151 def __mul__(self
, other
):
152 if other
.isconstant():
153 coefficients
= dict(self
.coefficients())
154 for symbol
in coefficients
:
155 coefficients
[symbol
] *= other
.constant
156 constant
= self
.constant
* other
.constant
157 return Expression(coefficients
, constant
)
158 if isinstance(other
, Expression
) and not self
.isconstant():
159 raise ValueError('non-linear expression: '
160 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
161 return NotImplemented
166 def __truediv__(self
, other
):
167 if other
.isconstant():
168 coefficients
= dict(self
.coefficients())
169 for symbol
in coefficients
:
170 coefficients
[symbol
] = \
171 Fraction(coefficients
[symbol
], other
.constant
)
172 constant
= Fraction(self
.constant
, other
.constant
)
173 return Expression(coefficients
, constant
)
174 if isinstance(other
, Expression
):
175 raise ValueError('non-linear expression: '
176 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
177 return NotImplemented
179 def __rtruediv__(self
, other
):
180 if isinstance(other
, Rational
):
181 if self
.isconstant():
182 constant
= Fraction(other
, self
.constant
)
183 return Expression(constant
=constant
)
185 raise ValueError('non-linear expression: '
186 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
187 return NotImplemented
191 symbols
= sorted(self
.symbols())
193 for symbol
in symbols
:
194 coefficient
= self
[symbol
]
199 string
+= ' + {}'.format(symbol
)
200 elif coefficient
== -1:
202 string
+= '-{}'.format(symbol
)
204 string
+= ' - {}'.format(symbol
)
207 string
+= '{}*{}'.format(coefficient
, symbol
)
208 elif coefficient
> 0:
209 string
+= ' + {}*{}'.format(coefficient
, symbol
)
211 assert coefficient
< 0
213 string
+= ' - {}*{}'.format(coefficient
, symbol
)
215 constant
= self
.constant
216 if constant
!= 0 and i
== 0:
217 string
+= '{}'.format(constant
)
219 string
+= ' + {}'.format(constant
)
222 string
+= ' - {}'.format(constant
)
227 def _parenstr(self
, always
=False):
229 if not always
and (self
.isconstant() or self
.issymbol()):
232 return '({})'.format(string
)
235 string
= '{}({{'.format(self
.__class
__.__name
__)
236 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
239 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
240 string
+= '}}, {!r})'.format(self
.constant
)
244 def fromstring(cls
, string
):
245 raise NotImplementedError
248 def __eq__(self
, other
):
250 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
251 return isinstance(other
, Expression
) and \
252 self
._coefficients
== other
._coefficients
and \
253 self
.constant
== other
.constant
256 return hash((self
._coefficients
, self
._constant
))
259 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
260 [value
.denominator
for value
in self
.values()])
264 def _eq(self
, other
):
265 return Polyhedron(equalities
=[(self
- other
)._canonify
()])
268 def __le__(self
, other
):
269 return Polyhedron(inequalities
=[(self
- other
)._canonify
()])
272 def __lt__(self
, other
):
273 return Polyhedron(inequalities
=[(self
- other
)._canonify
() + 1])
276 def __ge__(self
, other
):
277 return Polyhedron(inequalities
=[(other
- self
)._canonify
()])
280 def __gt__(self
, other
):
281 return Polyhedron(inequalities
=[(other
- self
)._canonify
() + 1])
284 def constant(numerator
=0, denominator
=None):
285 if denominator
is None and isinstance(numerator
, numbers
.Rational
):
286 return Expression(constant
=numerator
)
288 return Expression(constant
=Fraction(numerator
, denominator
))
291 if not isinstance(name
, str):
292 raise TypeError('name must be a string')
293 return Expression(coefficients
={name
: 1})
296 if isinstance(names
, str):
297 names
= names
.replace(',', ' ').split()
298 return (symbol(name
) for name
in names
)
301 @_polymorphic_operator
305 @_polymorphic_operator
309 @_polymorphic_operator
313 @_polymorphic_operator
317 @_polymorphic_operator
324 This class implements polyhedrons.
327 def __new__(cls
, equalities
=None, inequalities
=None):
328 if isinstance(equalities
, str):
329 if inequalities
is not None:
330 raise TypeError('too many arguments')
331 return cls
.fromstring(equalities
)
332 self
= super().__new
__(cls
)
333 self
._equalities
= []
334 if equalities
is not None:
335 for constraint
in equalities
:
336 for value
in constraint
.values():
337 if value
.denominator
!= 1:
338 raise TypeError('non-integer constraint: '
339 '{} == 0'.format(constraint
))
340 self
._equalities
.append(constraint
)
341 self
._inequalities
= []
342 if inequalities
is not None:
343 for constraint
in inequalities
:
344 for value
in constraint
.values():
345 if value
.denominator
!= 1:
346 raise TypeError('non-integer constraint: '
347 '{} <= 0'.format(constraint
))
348 self
._inequalities
.append(constraint
)
352 def equalities(self
):
353 yield from self
._equalities
356 def inequalities(self
):
357 yield from self
._inequalities
359 def constraints(self
):
360 yield from self
.equalities
361 yield from self
.inequalities
365 for constraint
in self
.constraints():
366 s
.update(constraint
.symbols
)
371 return len(self
.symbols())
374 # return false if the polyhedron is empty, true otherwise
375 raise NotImplementedError
377 def __contains__(self
, value
):
378 # is the value in the polyhedron?
379 raise NotImplementedError
381 def __eq__(self
, other
):
382 raise NotImplementedError
387 def isuniverse(self
):
388 return self
== universe
390 def isdisjoint(self
, other
):
391 # return true if the polyhedron has no elements in common with other
392 raise NotImplementedError
394 def issubset(self
, other
):
395 raise NotImplementedError
397 def __le__(self
, other
):
398 return self
.issubset(other
)
400 def __lt__(self
, other
):
401 raise NotImplementedError
403 def issuperset(self
, other
):
404 # test whether every element in other is in the polyhedron
405 raise NotImplementedError
407 def __ge__(self
, other
):
408 return self
.issuperset(other
)
410 def __gt__(self
, other
):
411 raise NotImplementedError
413 def union(self
, *others
):
414 # return a new polyhedron with elements from the polyhedron and all
415 # others (convex union)
416 raise NotImplementedError
418 def __or__(self
, other
):
419 return self
.union(other
)
421 def intersection(self
, *others
):
422 # return a new polyhedron with elements common to the polyhedron and all
424 # a poor man's implementation could be:
425 # equalities = list(self.equalities)
426 # inequalities = list(self.inequalities)
427 # for other in others:
428 # equalities.extend(other.equalities)
429 # inequalities.extend(other.inequalities)
430 # return self.__class__(equalities, inequalities)
431 raise NotImplementedError
433 def __and__(self
, other
):
434 return self
.intersection(other
)
436 def difference(self
, *others
):
437 # return a new polyhedron with elements in the polyhedron that are not
439 raise NotImplementedError
441 def __sub__(self
, other
):
442 return self
.difference(other
)
446 for constraint
in self
.equalities
:
447 constraints
.append('{} == 0'.format(constraint
))
448 for constraint
in self
.inequalities
:
449 constraints
.append('{} <= 0'.format(constraint
))
450 return '{{{}}}'.format(', '.join(constraints
))
453 equalities
= list(self
.equalities
)
454 inequalities
= list(self
.inequalities
)
455 return '{}(equalities={!r}, inequalities={!r})' \
456 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
459 def fromstring(cls
, string
):
460 raise NotImplementedError
465 universe
= Polyhedron()