ccd1564fb21cb865e226b83e0e6e1cdb5e60d24a
6 from collections
import OrderedDict
, defaultdict
7 from fractions
import Fraction
, gcd
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Constant(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
41 def __new__(cls
, coefficients
=None, constant
=0):
42 if isinstance(coefficients
, str):
44 raise TypeError('too many arguments')
45 return Expression
.fromstring(coefficients
)
46 if coefficients
is None:
47 return Constant(constant
)
48 if isinstance(coefficients
, dict):
49 coefficients
= coefficients
.items()
50 for symbol
, coefficient
in coefficients
:
51 if not isinstance(symbol
, Symbol
):
52 raise TypeError('symbols must be Symbol instances')
53 coefficients
= [(symbol
, coefficient
)
54 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
55 if len(coefficients
) == 0:
56 return Constant(constant
)
57 if len(coefficients
) == 1 and constant
== 0:
58 symbol
, coefficient
= coefficients
[0]
61 self
= object().__new
__(cls
)
62 self
._coefficients
= OrderedDict()
63 for symbol
, coefficient
in sorted(coefficients
,
64 key
=lambda item
: item
[0].name
):
65 if isinstance(coefficient
, Constant
):
66 coefficient
= coefficient
.constant
67 if not isinstance(coefficient
, numbers
.Rational
):
68 raise TypeError('coefficients must be rational numbers '
69 'or Constant instances')
70 self
._coefficients
[symbol
] = coefficient
71 if isinstance(constant
, Constant
):
72 constant
= constant
.constant
73 if not isinstance(constant
, numbers
.Rational
):
74 raise TypeError('constant must be a rational number '
75 'or a Constant instance')
76 self
._constant
= constant
77 self
._symbols
= tuple(self
._coefficients
)
78 self
._dimension
= len(self
._symbols
)
81 def coefficient(self
, symbol
):
82 if not isinstance(symbol
, Symbol
):
83 raise TypeError('symbol must be a Symbol instance')
85 return self
._coefficients
[symbol
]
89 __getitem__
= coefficient
def coefficients(self):
    """Iterate over the ``(symbol, coefficient)`` pairs of the expression.

    Pairs are produced in the iteration order of the internal
    ``_coefficients`` mapping.
    """
    for pair in self._coefficients.items():
        yield pair
104 return self
._dimension
107 return hash((tuple(self
._coefficients
.items()), self
._constant
))
109 def isconstant(self
):
116 yield from self
._coefficients
.values()
def __add__(self, other):
    """Return the sum of this expression and ``other``.

    The coefficients of both operands are accumulated per symbol and the
    constant terms are added; the result is built with ``Expression``.
    """
    # Symbols that only occur in ``other`` start from ``Constant()``.
    totals = defaultdict(Constant, self.coefficients())
    for sym, coeff in other.coefficients():
        totals[sym] += coeff
    return Expression(totals, self.constant + other.constant)
def __sub__(self, other):
    """Return the difference of this expression and ``other``.

    The coefficients of ``other`` are subtracted per symbol from those of
    ``self`` and the constant terms are subtracted; the result is built
    with ``Expression``.
    """
    # Symbols that only occur in ``other`` start from ``Constant()``.
    remaining = defaultdict(Constant, self.coefficients())
    for sym, coeff in other.coefficients():
        remaining[sym] -= coeff
    return Expression(remaining, self.constant - other.constant)
def __rsub__(self, other):
    """Return ``other - self``, computed as the negation of ``self - other``."""
    difference = self - other
    return -difference
def __mul__(self, other):
    """Return the product of this expression and a constant ``other``.

    Every coefficient and the constant term are multiplied by
    ``other.constant``.

    Raises:
        ValueError: if ``other`` is a non-constant ``Expression`` and
            ``self`` is not constant either (the product is non-linear).

    Returns ``NotImplemented`` for any other non-constant operand.
    """
    if not other.isconstant():
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: {} * {}'.format(
                self._parenstr(), other._parenstr()))
        return NotImplemented
    factor = other.constant
    scaled = {sym: coeff * factor for sym, coeff in self.coefficients()}
    return Expression(scaled, self.constant * factor)
def __truediv__(self, other):
    """Return the quotient of this expression by a constant ``other``.

    Every coefficient and the constant term are divided by
    ``other.constant``; each quotient is built as
    ``Constant(numerator, denominator)``.

    Raises:
        ValueError: if ``other`` is a non-constant ``Expression`` (the
            quotient is non-linear).

    Returns ``NotImplemented`` for any other non-constant operand.
    """
    if not other.isconstant():
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: {} / {}'.format(
                self._parenstr(), other._parenstr()))
        return NotImplemented
    divisor = other.constant
    quotients = {sym: Constant(coeff, divisor)
                 for sym, coeff in self.coefficients()}
    return Expression(quotients, Constant(self.constant, divisor))
def __rtruediv__(self, other):
    """Return the reflected quotient ``other / self``.

    The division is only defined when ``self`` is constant.

    Raises:
        ValueError: if ``other`` is an ``Expression`` and ``self`` is not
            constant (the quotient is non-linear).

    Returns ``NotImplemented`` when ``other`` is not an ``Expression``.
    """
    # The second argument of isinstance() must be a class; the previous
    # check passed the instance ``self`` and therefore always raised
    # TypeError instead of performing the division.
    if not isinstance(other, Expression):
        return NotImplemented
    if self.isconstant():
        # Hand the plain rational value of a constant numerator to
        # Constant, since it forwards both arguments to Fraction.
        numerator = other.constant if other.isconstant() else other
        return Constant(numerator, self.constant)
    raise ValueError('non-linear expression: {} / {}'.format(
        other._parenstr(), self._parenstr()))
def __eq__(self, other):
    """Return whether ``other`` is an Expression with equal terms.

    This is plain Python equality: it compares the coefficient mappings
    and the constant terms.
    """
    # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
    if not isinstance(other, Expression):
        return False
    if not (self._coefficients == other._coefficients):
        return False
    return self.constant == other.constant
def __le__(self, other):
    """Return ``Le(self, other)`` from the sibling ``polyhedra`` module."""
    # Imported at call time; presumably to avoid a circular import
    # between the two modules -- TODO confirm.
    from .polyhedra import Le as build
    return build(self, other)
def __lt__(self, other):
    """Return ``Lt(self, other)`` from the sibling ``polyhedra`` module."""
    # Imported at call time; presumably to avoid a circular import
    # between the two modules -- TODO confirm.
    from .polyhedra import Lt as build
    return build(self, other)
def __ge__(self, other):
    """Return ``Ge(self, other)`` from the sibling ``polyhedra`` module."""
    # Imported at call time; presumably to avoid a circular import
    # between the two modules -- TODO confirm.
    from .polyhedra import Ge as build
    return build(self, other)
def __gt__(self, other):
    """Return ``Gt(self, other)`` from the sibling ``polyhedra`` module."""
    # Imported at call time; presumably to avoid a circular import
    # between the two modules -- TODO confirm.
    from .polyhedra import Gt as build
    return build(self, other)
215 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
216 [value
.denominator
for value
in self
.values()])
219 def subs(self
, symbol
, expression
=None):
220 if expression
is None:
221 if isinstance(symbol
, dict):
222 symbol
= symbol
.items()
223 substitutions
= symbol
225 substitutions
= [(symbol
, expression
)]
227 for symbol
, expression
in substitutions
:
228 coefficients
= [(othersymbol
, coefficient
)
229 for othersymbol
, coefficient
in result
.coefficients()
230 if othersymbol
!= symbol
]
231 coefficient
= result
.coefficient(symbol
)
232 constant
= result
.constant
233 result
= Expression(coefficients
, constant
) + coefficient
*expression
237 def _fromast(cls
, node
):
238 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
239 return cls
._fromast
(node
.body
[0])
240 elif isinstance(node
, ast
.Expr
):
241 return cls
._fromast
(node
.value
)
242 elif isinstance(node
, ast
.Name
):
243 return Symbol(node
.id)
244 elif isinstance(node
, ast
.Num
):
245 return Constant(node
.n
)
246 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
247 return -cls
._fromast
(node
.operand
)
248 elif isinstance(node
, ast
.BinOp
):
249 left
= cls
._fromast
(node
.left
)
250 right
= cls
._fromast
(node
.right
)
251 if isinstance(node
.op
, ast
.Add
):
253 elif isinstance(node
.op
, ast
.Sub
):
255 elif isinstance(node
.op
, ast
.Mult
):
257 elif isinstance(node
.op
, ast
.Div
):
259 raise SyntaxError('invalid syntax')
261 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
def fromstring(cls, string):
    """Parse ``string`` and convert the resulting syntax tree.

    Implicit products such as ``'5x'`` or ``'2(x + 1)'`` are first rewritten
    with an explicit ``*``; the text is then parsed with ``ast.parse`` and
    the tree is handed to ``cls._fromast``.
    """
    # add implicit multiplication operators, e.g. '5x' -> '5*x'
    normalized = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
    # NOTE: the second positional parameter of ast.parse() is the file
    # name, not the mode, so the source is parsed in the default 'exec'
    # mode and the root node is an ast.Module.
    module = ast.parse(normalized, 'eval')
    return cls._fromast(module)
273 for symbol
in self
.symbols
:
274 coefficient
= self
.coefficient(symbol
)
277 string
+= symbol
.name
279 string
+= ' + {}'.format(symbol
)
280 elif coefficient
== -1:
282 string
+= '-{}'.format(symbol
)
284 string
+= ' - {}'.format(symbol
)
287 string
+= '{}*{}'.format(coefficient
, symbol
)
288 elif coefficient
> 0:
289 string
+= ' + {}*{}'.format(coefficient
, symbol
)
291 assert coefficient
< 0
293 string
+= ' - {}*{}'.format(coefficient
, symbol
)
295 constant
= self
.constant
296 if constant
!= 0 and i
== 0:
297 string
+= '{}'.format(constant
)
299 string
+= ' + {}'.format(constant
)
302 string
+= ' - {}'.format(constant
)
307 def _parenstr(self
, always
=False):
309 if not always
and (self
.isconstant() or self
.issymbol()):
312 return '({})'.format(string
)
315 def fromsympy(cls
, expr
):
319 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
320 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
321 if symbol
== sympy
.S
.One
:
322 constant
= coefficient
323 elif isinstance(symbol
, sympy
.Symbol
):
324 symbol
= Symbol(symbol
.name
)
325 coefficients
.append((symbol
, coefficient
))
327 raise ValueError('non-linear expression: {!r}'.format(expr
))
328 return Expression(coefficients
, constant
)
333 for symbol
, coefficient
in self
.coefficients():
334 term
= coefficient
* sympy
.Symbol(symbol
.name
)
336 expr
+= self
.constant
340 class Symbol(Expression
):
346 def __new__(cls
, name
):
347 if not isinstance(name
, str):
348 raise TypeError('name must be a string')
349 self
= object().__new
__(cls
)
350 self
._name
= name
.strip()
358 return hash(self
._name
)
360 def coefficient(self
, symbol
):
361 if not isinstance(symbol
, Symbol
):
362 raise TypeError('symbol must be a Symbol instance')
368 def coefficients(self
):
def __eq__(self, other):
    """Return whether ``other`` is a Symbol with the same name."""
    if not isinstance(other, Symbol):
        return False
    return self.name == other.name
def _fromast(cls, node):
    """Convert a syntax tree holding a single name into a Symbol.

    A module with exactly one statement and an expression statement are
    unwrapped recursively; a name node yields ``Symbol(node.id)``.

    Raises:
        SyntaxError: for any other kind of node.
    """
    if isinstance(node, ast.Module) and len(node.body) == 1:
        return cls._fromast(node.body[0])
    if isinstance(node, ast.Expr):
        return cls._fromast(node.value)
    if isinstance(node, ast.Name):
        return Symbol(node.id)
    raise SyntaxError('invalid syntax')
403 def fromsympy(cls
, expr
):
405 if isinstance(expr
, sympy
.Symbol
):
406 return Symbol(expr
.name
)
408 raise TypeError('expr must be a sympy.Symbol instance')
412 if isinstance(names
, str):
413 names
= names
.replace(',', ' ').split()
414 return tuple(Symbol(name
) for name
in names
)
417 class Constant(Expression
):
423 def __new__(cls
, numerator
=0, denominator
=None):
424 self
= object().__new
__(cls
)
425 if denominator
is None and isinstance(numerator
, Constant
):
426 self
._constant
= numerator
.constant
428 self
._constant
= Fraction(numerator
, denominator
)
432 return hash(self
.constant
)
434 def coefficient(self
, symbol
):
435 if not isinstance(symbol
, Symbol
):
436 raise TypeError('symbol must be a Symbol instance')
439 def coefficients(self
):
450 def isconstant(self
):
def __eq__(self, other):
    """Return whether ``other`` is a Constant with the same value."""
    if not isinstance(other, Constant):
        return False
    return self.constant == other.constant
461 return self
.constant
!= 0
def fromstring(cls, string):
    """Build a Constant from the textual form of a rational number.

    The text is converted with ``Fraction(string)``.

    Raises:
        TypeError: if ``string`` is not a ``str``.
    """
    if isinstance(string, str):
        return Constant(Fraction(string))
    raise TypeError('string must be a string instance')
470 def fromsympy(cls
, expr
):
472 if isinstance(expr
, sympy
.Rational
):
473 return Constant(expr
.p
, expr
.q
)
474 elif isinstance(expr
, numbers
.Rational
):
475 return Constant(expr
)
477 raise TypeError('expr must be a sympy.Rational instance')