b23eea8845db0a425efde0d1578445f5915b0dac
6 from collections
import OrderedDict
, defaultdict
, Mapping
7 from fractions
import Fraction
, gcd
12 'Symbol', 'Dummy', 'symbols',
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Rational(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
41 def __new__(cls
, coefficients
=None, constant
=0):
42 if isinstance(coefficients
, str):
44 raise TypeError('too many arguments')
45 return Expression
.fromstring(coefficients
)
46 if coefficients
is None:
47 return Rational(constant
)
48 if isinstance(coefficients
, Mapping
):
49 coefficients
= coefficients
.items()
50 for symbol
, coefficient
in coefficients
:
51 if not isinstance(symbol
, Symbol
):
52 raise TypeError('symbols must be Symbol instances')
53 coefficients
= [(symbol
, coefficient
)
54 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
55 if len(coefficients
) == 0:
56 return Rational(constant
)
57 if len(coefficients
) == 1 and constant
== 0:
58 symbol
, coefficient
= coefficients
[0]
61 self
= object().__new
__(cls
)
62 self
._coefficients
= OrderedDict()
63 for symbol
, coefficient
in sorted(coefficients
,
64 key
=lambda item
: item
[0].sortkey()):
65 if isinstance(coefficient
, Rational
):
66 coefficient
= coefficient
.constant
67 if not isinstance(coefficient
, numbers
.Rational
):
68 raise TypeError('coefficients must be Rational instances')
69 self
._coefficients
[symbol
] = coefficient
70 if isinstance(constant
, Rational
):
71 constant
= constant
.constant
72 if not isinstance(constant
, numbers
.Rational
):
73 raise TypeError('constant must be a Rational instance')
74 self
._constant
= constant
75 self
._symbols
= tuple(self
._coefficients
)
76 self
._dimension
= len(self
._symbols
)
79 def coefficient(self
, symbol
):
80 if not isinstance(symbol
, Symbol
):
81 raise TypeError('symbol must be a Symbol instance')
83 return self
._coefficients
[symbol
]
87 __getitem__
= coefficient
89 def coefficients(self
):
90 yield from self
._coefficients
.items()
102 return self
._dimension
105 return hash((tuple(self
._coefficients
.items()), self
._constant
))
107 def isconstant(self
):
114 yield from self
._coefficients
.values()
127 def __add__(self
, other
):
128 coefficients
= defaultdict(Rational
, self
.coefficients())
129 for symbol
, coefficient
in other
.coefficients():
130 coefficients
[symbol
] += coefficient
131 constant
= self
.constant
+ other
.constant
132 return Expression(coefficients
, constant
)
137 def __sub__(self
, other
):
138 coefficients
= defaultdict(Rational
, self
.coefficients())
139 for symbol
, coefficient
in other
.coefficients():
140 coefficients
[symbol
] -= coefficient
141 constant
= self
.constant
- other
.constant
142 return Expression(coefficients
, constant
)
144 def __rsub__(self
, other
):
145 return -(self
- other
)
148 def __mul__(self
, other
):
149 if other
.isconstant():
150 coefficients
= dict(self
.coefficients())
151 for symbol
in coefficients
:
152 coefficients
[symbol
] *= other
.constant
153 constant
= self
.constant
* other
.constant
154 return Expression(coefficients
, constant
)
155 if isinstance(other
, Expression
) and not self
.isconstant():
156 raise ValueError('non-linear expression: '
157 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
158 return NotImplemented
163 def __truediv__(self
, other
):
164 if other
.isconstant():
165 coefficients
= dict(self
.coefficients())
166 for symbol
in coefficients
:
167 coefficients
[symbol
] = Rational(coefficients
[symbol
], other
.constant
)
168 constant
= Rational(self
.constant
, other
.constant
)
169 return Expression(coefficients
, constant
)
170 if isinstance(other
, Expression
):
171 raise ValueError('non-linear expression: '
172 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
173 return NotImplemented
175 def __rtruediv__(self
, other
):
176 if isinstance(other
, self
):
177 if self
.isconstant():
178 return Rational(other
, self
.constant
)
180 raise ValueError('non-linear expression: '
181 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
182 return NotImplemented
185 def __eq__(self
, other
):
187 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
188 return isinstance(other
, Expression
) and \
189 self
._coefficients
== other
._coefficients
and \
190 self
.constant
== other
.constant
193 def __le__(self
, other
):
194 from .polyhedra
import Le
195 return Le(self
, other
)
198 def __lt__(self
, other
):
199 from .polyhedra
import Lt
200 return Lt(self
, other
)
203 def __ge__(self
, other
):
204 from .polyhedra
import Ge
205 return Ge(self
, other
)
208 def __gt__(self
, other
):
209 from .polyhedra
import Gt
210 return Gt(self
, other
)
213 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
214 [value
.denominator
for value
in self
.values()])
217 def subs(self
, symbol
, expression
=None):
218 if expression
is None:
219 if isinstance(symbol
, Mapping
):
220 symbol
= symbol
.items()
221 substitutions
= symbol
223 substitutions
= [(symbol
, expression
)]
225 for symbol
, expression
in substitutions
:
226 coefficients
= [(othersymbol
, coefficient
)
227 for othersymbol
, coefficient
in result
.coefficients()
228 if othersymbol
!= symbol
]
229 coefficient
= result
.coefficient(symbol
)
230 constant
= result
.constant
231 result
= Expression(coefficients
, constant
) + coefficient
*expression
235 def _fromast(cls
, node
):
236 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
237 return cls
._fromast
(node
.body
[0])
238 elif isinstance(node
, ast
.Expr
):
239 return cls
._fromast
(node
.value
)
240 elif isinstance(node
, ast
.Name
):
241 return Symbol(node
.id)
242 elif isinstance(node
, ast
.Num
):
243 return Rational(node
.n
)
244 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
245 return -cls
._fromast
(node
.operand
)
246 elif isinstance(node
, ast
.BinOp
):
247 left
= cls
._fromast
(node
.left
)
248 right
= cls
._fromast
(node
.right
)
249 if isinstance(node
.op
, ast
.Add
):
251 elif isinstance(node
.op
, ast
.Sub
):
253 elif isinstance(node
.op
, ast
.Mult
):
255 elif isinstance(node
.op
, ast
.Div
):
257 raise SyntaxError('invalid syntax')
259 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
262 def fromstring(cls
, string
):
263 # add implicit multiplication operators, e.g. '5x' -> '5*x'
264 string
= Expression
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
265 tree
= ast
.parse(string
, 'eval')
266 return cls
._fromast
(tree
)
270 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
272 string
+= '' if i
== 0 else ' + '
273 string
+= '{!r}'.format(symbol
)
274 elif coefficient
== -1:
275 string
+= '-' if i
== 0 else ' - '
276 string
+= '{!r}'.format(symbol
)
279 string
+= '{}*{!r}'.format(coefficient
, symbol
)
280 elif coefficient
> 0:
281 string
+= ' + {}*{!r}'.format(coefficient
, symbol
)
283 string
+= ' - {}*{!r}'.format(-coefficient
, symbol
)
284 constant
= self
.constant
286 string
+= '{}'.format(constant
)
288 string
+= ' + {}'.format(constant
)
290 string
+= ' - {}'.format(-constant
)
293 def _parenstr(self
, always
=False):
295 if not always
and (self
.isconstant() or self
.issymbol()):
298 return '({})'.format(string
)
301 def fromsympy(cls
, expr
):
305 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
306 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
307 if symbol
== sympy
.S
.One
:
308 constant
= coefficient
309 elif isinstance(symbol
, sympy
.Symbol
):
310 symbol
= Symbol(symbol
.name
)
311 coefficients
.append((symbol
, coefficient
))
313 raise ValueError('non-linear expression: {!r}'.format(expr
))
314 return Expression(coefficients
, constant
)
319 for symbol
, coefficient
in self
.coefficients():
320 term
= coefficient
* sympy
.Symbol(symbol
.name
)
322 expr
+= self
.constant
326 class Symbol(Expression
):
332 def __new__(cls
, name
):
333 if not isinstance(name
, str):
334 raise TypeError('name must be a string')
335 self
= object().__new
__(cls
)
336 self
._name
= name
.strip()
344 return hash(self
.sortkey())
346 def coefficient(self
, symbol
):
347 if not isinstance(symbol
, Symbol
):
348 raise TypeError('symbol must be a Symbol instance')
354 def coefficients(self
):
378 def __eq__(self
, other
):
379 return not isinstance(other
, Dummy
) and isinstance(other
, Symbol
) \
380 and self
.name
== other
.name
383 return Dummy(self
.name
)
386 def _fromast(cls
, node
):
387 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
388 return cls
._fromast
(node
.body
[0])
389 elif isinstance(node
, ast
.Expr
):
390 return cls
._fromast
(node
.value
)
391 elif isinstance(node
, ast
.Name
):
392 return Symbol(node
.id)
393 raise SyntaxError('invalid syntax')
399 def fromsympy(cls
, expr
):
401 if isinstance(expr
, sympy
.Symbol
):
402 return cls(expr
.name
)
404 raise TypeError('expr must be a sympy.Symbol instance')
416 def __new__(cls
, name
=None):
418 name
= 'Dummy_{}'.format(Dummy
._count
)
419 self
= object().__new
__(cls
)
420 self
._name
= name
.strip()
421 self
._index
= Dummy
._count
426 return hash(self
.sortkey())
429 return self
._name
, self
._index
431 def __eq__(self
, other
):
432 return isinstance(other
, Dummy
) and self
._index
== other
._index
435 return '_{}'.format(self
.name
)
439 if isinstance(names
, str):
440 names
= names
.replace(',', ' ').split()
441 return tuple(Symbol(name
) for name
in names
)
444 class Rational(Expression
):
450 def __new__(cls
, numerator
=0, denominator
=None):
451 self
= object().__new
__(cls
)
452 if denominator
is None and isinstance(numerator
, Rational
):
453 self
._constant
= numerator
.constant
455 self
._constant
= Fraction(numerator
, denominator
)
459 return hash(self
.constant
)
461 def coefficient(self
, symbol
):
462 if not isinstance(symbol
, Symbol
):
463 raise TypeError('symbol must be a Symbol instance')
466 def coefficients(self
):
477 def isconstant(self
):
484 def __eq__(self
, other
):
485 return isinstance(other
, Rational
) and self
.constant
== other
.constant
488 return self
.constant
!= 0
491 def fromstring(cls
, string
):
492 if not isinstance(string
, str):
493 raise TypeError('string must be a string instance')
494 return Rational(Fraction(string
))
497 def fromsympy(cls
, expr
):
499 if isinstance(expr
, sympy
.Rational
):
500 return Rational(expr
.p
, expr
.q
)
501 elif isinstance(expr
, numbers
.Rational
):
502 return Rational(expr
)
504 raise TypeError('expr must be a sympy.Rational instance')