07d40052e2d240aefba921e84630d7e64a8c1dd0
6 from collections
import OrderedDict
, defaultdict
, Mapping
7 from fractions
import Fraction
, gcd
12 'Symbol', 'Dummy', 'symbols',
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Rational(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
34 def __new__(cls
, coefficients
=None, constant
=0):
35 if isinstance(coefficients
, str):
37 raise TypeError('too many arguments')
38 return Expression
.fromstring(coefficients
)
39 if coefficients
is None:
40 return Rational(constant
)
41 if isinstance(coefficients
, Mapping
):
42 coefficients
= coefficients
.items()
43 coefficients
= list(coefficients
)
44 for symbol
, coefficient
in coefficients
:
45 if not isinstance(symbol
, Symbol
):
46 raise TypeError('symbols must be Symbol instances')
47 if not isinstance(coefficient
, numbers
.Rational
):
48 raise TypeError('coefficients must be rational numbers')
49 if not isinstance(constant
, numbers
.Rational
):
50 raise TypeError('constant must be a rational number')
51 if len(coefficients
) == 0:
52 return Rational(constant
)
53 if len(coefficients
) == 1 and constant
== 0:
54 symbol
, coefficient
= coefficients
[0]
57 coefficients
= [(symbol
, Fraction(coefficient
))
58 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
59 coefficients
.sort(key
=lambda item
: item
[0].sortkey())
60 self
= object().__new
__(cls
)
61 self
._coefficients
= OrderedDict(coefficients
)
62 self
._constant
= Fraction(constant
)
63 self
._symbols
= tuple(self
._coefficients
)
64 self
._dimension
= len(self
._symbols
)
def coefficient(self, symbol):
    """
    Return the coefficient attached to ``symbol`` in this expression.

    The value is looked up in the internal coefficient mapping and wrapped
    in a Rational; a symbol that does not occur yields Rational(0).

    Raises TypeError if ``symbol`` is not a Symbol instance.
    """
    if isinstance(symbol, Symbol):
        value = self._coefficients.get(symbol, 0)
        return Rational(value)
    raise TypeError('symbol must be a Symbol instance')
72 __getitem__
= coefficient
def coefficients(self):
    """
    Iterate over the ``(symbol, coefficient)`` pairs of the expression.

    Pairs are produced lazily, in the iteration order of the internal
    coefficient mapping, with every coefficient wrapped in a Rational.
    """
    for entry in self._coefficients.items():
        sym, value = entry
        yield sym, Rational(value)
80 return Rational(self
._constant
)
88 return self
._dimension
91 return hash((tuple(self
._coefficients
.items()), self
._constant
))
100 for coefficient
in self
._coefficients
.values():
101 yield Rational(coefficient
)
102 yield Rational(self
._constant
)
def __add__(self, other):
    """
    Return the sum of this expression and ``other`` as a new Expression.

    ``other`` must expose ``_coefficients`` and ``_constant``.
    NOTE(review): plain numbers are presumably coerced by a decorator that
    is not visible in this block -- confirm against the decorator line.
    """
    # Start from a copy of our own coefficients; symbols that only occur
    # in ``other`` start from Fraction() == 0.
    merged = defaultdict(Fraction, self._coefficients)
    for sym, value in other._coefficients.items():
        merged[sym] = merged[sym] + value
    total = self._constant + other._constant
    return Expression(merged, total)
def __sub__(self, other):
    """
    Return the difference ``self - other`` as a new Expression.

    ``other`` must expose ``_coefficients`` and ``_constant``.
    NOTE(review): plain numbers are presumably coerced by a decorator that
    is not visible in this block -- confirm against the decorator line.
    """
    # Start from a copy of our own coefficients; symbols that only occur
    # in ``other`` start from Fraction() == 0 before being subtracted.
    merged = defaultdict(Fraction, self._coefficients)
    for sym, value in other._coefficients.items():
        merged[sym] = merged[sym] - value
    remainder = self._constant - other._constant
    return Expression(merged, remainder)
132 def __rsub__(self
, other
):
135 def __mul__(self
, other
):
136 if isinstance(other
, numbers
.Rational
):
137 coefficients
= ((symbol
, coefficient
* other
)
138 for symbol
, coefficient
in self
._coefficients
.items())
139 constant
= self
._constant
* other
140 return Expression(coefficients
, constant
)
141 return NotImplemented
145 def __truediv__(self
, other
):
146 if isinstance(other
, numbers
.Rational
):
147 coefficients
= ((symbol
, coefficient
/ other
)
148 for symbol
, coefficient
in self
._coefficients
.items())
149 constant
= self
._constant
/ other
150 return Expression(coefficients
, constant
)
151 return NotImplemented
def __eq__(self, other):
    """
    Test structural equality with ``other``.

    Returns a boolean, not a constraint; see
    http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs

    Two expressions are equal when ``other`` is an Expression whose
    coefficient mapping and constant term both equal ours.
    """
    if not isinstance(other, Expression):
        return False
    if not (self._coefficients == other._coefficients):
        return False
    return self._constant == other._constant
def __le__(self, other):
    """
    Return ``Le(self, other)``, the object built for ``self <= other``.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import with .polyhedra -- confirm).
    from .polyhedra import Le as _Le
    relation = _Le(self, other)
    return relation
def __lt__(self, other):
    """
    Return ``Lt(self, other)``, the object built for ``self < other``.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import with .polyhedra -- confirm).
    from .polyhedra import Lt as _Lt
    relation = _Lt(self, other)
    return relation
def __ge__(self, other):
    """
    Return ``Ge(self, other)``, the object built for ``self >= other``.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import with .polyhedra -- confirm).
    from .polyhedra import Ge as _Ge
    relation = _Ge(self, other)
    return relation
def __gt__(self, other):
    """
    Return ``Gt(self, other)``, the object built for ``self > other``.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import with .polyhedra -- confirm).
    from .polyhedra import Gt as _Gt
    relation = _Gt(self, other)
    return relation
178 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
179 [value
.denominator
for value
in self
.values()])
182 def subs(self
, symbol
, expression
=None):
183 if expression
is None:
184 if isinstance(symbol
, Mapping
):
185 symbol
= symbol
.items()
186 substitutions
= symbol
188 substitutions
= [(symbol
, expression
)]
190 for symbol
, expression
in substitutions
:
191 if not isinstance(symbol
, Symbol
):
192 raise TypeError('symbols must be Symbol instances')
193 coefficients
= [(othersymbol
, coefficient
)
194 for othersymbol
, coefficient
in result
._coefficients
.items()
195 if othersymbol
!= symbol
]
196 coefficient
= result
._coefficients
.get(symbol
, 0)
197 constant
= result
._constant
198 result
= Expression(coefficients
, constant
) + coefficient
*expression
202 def _fromast(cls
, node
):
203 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
204 return cls
._fromast
(node
.body
[0])
205 elif isinstance(node
, ast
.Expr
):
206 return cls
._fromast
(node
.value
)
207 elif isinstance(node
, ast
.Name
):
208 return Symbol(node
.id)
209 elif isinstance(node
, ast
.Num
):
210 return Rational(node
.n
)
211 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
212 return -cls
._fromast
(node
.operand
)
213 elif isinstance(node
, ast
.BinOp
):
214 left
= cls
._fromast
(node
.left
)
215 right
= cls
._fromast
(node
.right
)
216 if isinstance(node
.op
, ast
.Add
):
218 elif isinstance(node
.op
, ast
.Sub
):
220 elif isinstance(node
.op
, ast
.Mult
):
222 elif isinstance(node
.op
, ast
.Div
):
224 raise SyntaxError('invalid syntax')
226 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
def fromstring(cls, string):
    """
    Parse ``string`` and convert it through ``cls._fromast``.

    Implicit multiplications are made explicit first, e.g. '5x' -> '5*x',
    then the text is parsed with the ``ast`` module and the resulting tree
    is handed to ``cls._fromast``.
    """
    # Insert '*' between a number (or closing parenthesis) and a following
    # name (or opening parenthesis).
    expanded = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
    # NOTE: the second positional parameter of ast.parse is the file name,
    # so 'eval' only labels the source; the parse mode stays the default
    # 'exec' and the result is an ast.Module.
    module = ast.parse(expanded, filename='eval')
    return cls._fromast(module)
237 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
241 elif coefficient
== -1:
242 string
+= '-' if i
== 0 else ' - '
244 string
+= '{}*'.format(coefficient
)
245 elif coefficient
> 0:
246 string
+= ' + {}*'.format(coefficient
)
248 string
+= ' - {}*'.format(-coefficient
)
249 string
+= '{}'.format(symbol
)
250 constant
= self
.constant
252 string
+= '{}'.format(constant
)
254 string
+= ' + {}'.format(constant
)
256 string
+= ' - {}'.format(-constant
)
259 def _repr_latex_(self
):
261 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
265 elif coefficient
== -1:
266 string
+= '-' if i
== 0 else ' - '
268 string
+= '{}'.format(coefficient
._repr
_latex
_().strip('$'))
269 elif coefficient
> 0:
270 string
+= ' + {}'.format(coefficient
._repr
_latex
_().strip('$'))
271 elif coefficient
< 0:
272 string
+= ' - {}'.format((-coefficient
)._repr
_latex
_().strip('$'))
273 string
+= '{}'.format(symbol
._repr
_latex
_().strip('$'))
274 constant
= self
.constant
276 string
+= '{}'.format(constant
._repr
_latex
_().strip('$'))
278 string
+= ' + {}'.format(constant
._repr
_latex
_().strip('$'))
280 string
+= ' - {}'.format((-constant
)._repr
_latex
_().strip('$'))
281 return '$${}$$'.format(string
)
283 def _parenstr(self
, always
=False):
285 if not always
and (self
.isconstant() or self
.issymbol()):
288 return '({})'.format(string
)
291 def fromsympy(cls
, expr
):
295 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
296 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
297 if symbol
== sympy
.S
.One
:
298 constant
= coefficient
299 elif isinstance(symbol
, sympy
.Symbol
):
300 symbol
= Symbol(symbol
.name
)
301 coefficients
.append((symbol
, coefficient
))
303 raise ValueError('non-linear expression: {!r}'.format(expr
))
304 return Expression(coefficients
, constant
)
309 for symbol
, coefficient
in self
.coefficients():
310 term
= coefficient
* sympy
.Symbol(symbol
.name
)
312 expr
+= self
.constant
316 class Symbol(Expression
):
318 def __new__(cls
, name
):
319 if not isinstance(name
, str):
320 raise TypeError('name must be a string')
321 self
= object().__new
__(cls
)
322 self
._name
= name
.strip()
323 self
._coefficients
= {self
: Fraction(1)}
324 self
._constant
= Fraction(0)
325 self
._symbols
= (self
,)
334 return hash(self
.sortkey())
def __eq__(self, other):
    """
    Test whether ``other`` is a symbol with the same sort key.

    Returns a boolean. Operands that are not Symbol instances compare
    unequal instead of raising AttributeError from a missing ``sortkey``
    method.
    """
    # Check the type first: only symbols are known to provide sortkey().
    return isinstance(other, Symbol) and self.sortkey() == other.sortkey()
346 return Dummy(self
.name
)
def _fromast(cls, node):
    """
    Convert an ``ast`` node into a Symbol.

    An ast.Name becomes ``Symbol(node.id)``. An ast.Expr, or an ast.Module
    holding exactly one statement, is unwrapped and converted recursively.

    Raises SyntaxError('invalid syntax') for any other node.
    """
    # The three node classes are mutually exclusive, so the checks can be
    # made in any order.
    if isinstance(node, ast.Name):
        return Symbol(node.id)
    if isinstance(node, ast.Expr):
        return cls._fromast(node.value)
    if isinstance(node, ast.Module) and len(node.body) == 1:
        return cls._fromast(node.body[0])
    raise SyntaxError('invalid syntax')
361 def _repr_latex_(self
):
362 return '$${}$$'.format(self
.name
)
365 def fromsympy(cls
, expr
):
367 if isinstance(expr
, sympy
.Dummy
):
368 return Dummy(expr
.name
)
369 elif isinstance(expr
, sympy
.Symbol
):
370 return Symbol(expr
.name
)
372 raise TypeError('expr must be a sympy.Symbol instance')
379 def __new__(cls
, name
=None):
381 name
= 'Dummy_{}'.format(Dummy
._count
)
382 elif not isinstance(name
, str):
383 raise TypeError('name must be a string')
384 self
= object().__new
__(cls
)
385 self
._index
= Dummy
._count
386 self
._name
= name
.strip()
387 self
._coefficients
= {self
: Fraction(1)}
388 self
._constant
= Fraction(0)
389 self
._symbols
= (self
,)
395 return hash(self
.sortkey())
398 return self
._name
, self
._index
401 return '_{}'.format(self
.name
)
403 def _repr_latex_(self
):
404 return '$${}_{{{}}}$$'.format(self
.name
, self
._index
)
408 if isinstance(names
, str):
409 names
= names
.replace(',', ' ').split()
410 return tuple(Symbol(name
) for name
in names
)
413 class Rational(Expression
, Fraction
):
415 def __new__(cls
, numerator
=0, denominator
=None):
416 self
= object().__new
__(cls
)
417 self
._coefficients
= {}
418 self
._constant
= Fraction(numerator
, denominator
)
421 self
._numerator
= self
._constant
.numerator
422 self
._denominator
= self
._constant
.denominator
426 return Fraction
.__hash
__(self
)
432 def isconstant(self
):
436 return Fraction
.__bool
__(self
)
439 if self
.denominator
== 1:
440 return '{!r}'.format(self
.numerator
)
442 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)
444 def _repr_latex_(self
):
445 if self
.denominator
== 1:
446 return '$${}$$'.format(self
.numerator
)
447 elif self
.numerator
< 0:
448 return '$$-\\frac{{{}}}{{{}}}$$'.format(-self
.numerator
,
451 return '$$\\frac{{{}}}{{{}}}$$'.format(self
.numerator
,
455 def fromsympy(cls
, expr
):
457 if isinstance(expr
, sympy
.Rational
):
458 return Rational(expr
.p
, expr
.q
)
459 elif isinstance(expr
, numbers
.Rational
):
460 return Rational(expr
)
462 raise TypeError('expr must be a sympy.Rational instance')