6 from collections
import OrderedDict
, defaultdict
, Mapping
7 from fractions
import Fraction
, gcd
12 'Symbol', 'Dummy', 'symbols',
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Rational(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
34 def __new__(cls
, coefficients
=None, constant
=0):
35 if isinstance(coefficients
, str):
37 raise TypeError('too many arguments')
38 return Expression
.fromstring(coefficients
)
39 if coefficients
is None:
40 return Rational(constant
)
41 if isinstance(coefficients
, Mapping
):
42 coefficients
= coefficients
.items()
43 for symbol
, coefficient
in coefficients
:
44 if not isinstance(symbol
, Symbol
):
45 raise TypeError('symbols must be Symbol instances')
46 if not isinstance(coefficient
, numbers
.Rational
):
47 raise TypeError('coefficients must be rational numbers')
48 coefficients
= [(symbol
, Fraction(coefficient
))
49 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
50 if not isinstance(constant
, numbers
.Rational
):
51 raise TypeError('constant must be a rational number')
52 constant
= Fraction(constant
)
53 if len(coefficients
) == 0:
54 return Rational(constant
)
55 if len(coefficients
) == 1 and constant
== 0:
56 symbol
, coefficient
= coefficients
[0]
59 self
= object().__new
__(cls
)
60 self
._coefficients
= OrderedDict(sorted(coefficients
,
61 key
=lambda item
: item
[0].sortkey()))
62 self
._constant
= constant
63 self
._symbols
= tuple(self
._coefficients
)
64 self
._dimension
= len(self
._symbols
)
67 def coefficient(self
, symbol
):
68 if not isinstance(symbol
, Symbol
):
69 raise TypeError('symbol must be a Symbol instance')
71 return Rational(self
._coefficients
[symbol
])
75 __getitem__
= coefficient
def coefficients(self):
    """
    Iterate over the (symbol, coefficient) pairs of the expression.

    Each stored coefficient is wrapped in a Rational before being yielded.
    """
    stored = self._coefficients
    for variable in stored:
        yield variable, Rational(stored[variable])
83 return Rational(self
._constant
)
91 return self
._dimension
94 return hash((tuple(self
._coefficients
.items()), self
._constant
))
103 for coefficient
in self
._coefficients
.values():
104 yield Rational(coefficient
)
105 yield Rational(self
._constant
)
def __add__(self, other):
    """
    Return the sum of this expression and *other*.

    Coefficients are added symbol by symbol and the constant terms are
    summed; the result is built by calling Expression().
    """
    # Symbols that only appear in `other` start from Fraction() == 0.
    totals = defaultdict(Fraction, self._coefficients)
    for variable, weight in other._coefficients.items():
        totals[variable] += weight
    return Expression(totals, self._constant + other._constant)
def __sub__(self, other):
    """
    Return the difference between this expression and *other*.

    Coefficients of *other* are subtracted symbol by symbol and its
    constant term is subtracted from this one; the result is built by
    calling Expression().
    """
    # Symbols that only appear in `other` start from Fraction() == 0.
    remaining = defaultdict(Fraction, self._coefficients)
    for variable, weight in other._coefficients.items():
        remaining[variable] -= weight
    return Expression(remaining, self._constant - other._constant)
def __rsub__(self, other):
    """
    Return ``other - self``, computed as the negation of ``self - other``.
    """
    difference = self - other
    return -difference
def __mul__(self, other):
    """
    Return the product of the expression by a Rational.

    The computation is delegated to the ``__rmul__`` method of the
    rational operand; any other operand type yields NotImplemented.
    """
    if not isinstance(other, Rational):
        return NotImplemented
    return other.__rmul__(self)
def __truediv__(self, other):
    """
    Return the quotient of the expression by a Rational.

    The computation is delegated to the ``__rtruediv__`` method of the
    rational operand; any other operand type yields NotImplemented.
    """
    if isinstance(other, Rational):
        return other.__rtruediv__(self)
    return NotImplemented

def __rtruediv__(self, other):
    """
    Reflected division (``other / self``) is not supported here.

    This name used to be an alias of ``__truediv__``, which made
    ``other / self`` silently evaluate as ``self / other``.  Returning
    NotImplemented lets Python raise TypeError instead of producing that
    swapped quotient.  Rational defines its own ``__rtruediv__`` and is
    therefore unaffected.
    """
    return NotImplemented
154 def __eq__(self
, other
):
156 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
157 return isinstance(other
, Expression
) and \
158 self
._coefficients
== other
._coefficients
and \
159 self
._constant
== other
._constant
def __le__(self, other):
    """
    Return ``Le(self, other)`` instead of a boolean.
    """
    # Imported at call time (presumably to avoid a circular import with
    # the polyhedra module -- confirm).
    from .polyhedra import Le as build
    return build(self, other)
def __lt__(self, other):
    """
    Return ``Lt(self, other)`` instead of a boolean.
    """
    # Imported at call time (presumably to avoid a circular import with
    # the polyhedra module -- confirm).
    from .polyhedra import Lt as build
    return build(self, other)
def __ge__(self, other):
    """
    Return ``Ge(self, other)`` instead of a boolean.
    """
    # Imported at call time (presumably to avoid a circular import with
    # the polyhedra module -- confirm).
    from .polyhedra import Ge as build
    return build(self, other)
def __gt__(self, other):
    """
    Return ``Gt(self, other)`` instead of a boolean.
    """
    # Imported at call time (presumably to avoid a circular import with
    # the polyhedra module -- confirm).
    from .polyhedra import Gt as build
    return build(self, other)
182 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
183 [value
.denominator
for value
in self
.values()])
186 def subs(self
, symbol
, expression
=None):
187 if expression
is None:
188 if isinstance(symbol
, Mapping
):
189 symbol
= symbol
.items()
190 substitutions
= symbol
192 substitutions
= [(symbol
, expression
)]
194 for symbol
, expression
in substitutions
:
195 if not isinstance(symbol
, Symbol
):
196 raise TypeError('symbols must be Symbol instances')
197 coefficients
= [(othersymbol
, coefficient
)
198 for othersymbol
, coefficient
in result
._coefficients
.items()
199 if othersymbol
!= symbol
]
200 coefficient
= result
._coefficients
.get(symbol
, 0)
201 constant
= result
._constant
202 result
= Expression(coefficients
, constant
) + coefficient
*expression
206 def _fromast(cls
, node
):
207 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
208 return cls
._fromast
(node
.body
[0])
209 elif isinstance(node
, ast
.Expr
):
210 return cls
._fromast
(node
.value
)
211 elif isinstance(node
, ast
.Name
):
212 return Symbol(node
.id)
213 elif isinstance(node
, ast
.Num
):
214 return Rational(node
.n
)
215 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
216 return -cls
._fromast
(node
.operand
)
217 elif isinstance(node
, ast
.BinOp
):
218 left
= cls
._fromast
(node
.left
)
219 right
= cls
._fromast
(node
.right
)
220 if isinstance(node
.op
, ast
.Add
):
222 elif isinstance(node
.op
, ast
.Sub
):
224 elif isinstance(node
.op
, ast
.Mult
):
226 elif isinstance(node
.op
, ast
.Div
):
228 raise SyntaxError('invalid syntax')
230 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
def fromstring(cls, string):
    """
    Build an expression from its textual form.

    Implicit products (a number or a closing parenthesis directly followed
    by a name or an opening parenthesis, e.g. '5x' -> '5*x') are first
    rewritten with an explicit '*'.  The text is then parsed with
    ast.parse() and the resulting tree is converted by cls._fromast().
    """
    explicit = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
    # NOTE(review): the second positional argument of ast.parse() is the
    # file name, not the mode, so the text is parsed in the default 'exec'
    # mode and 'eval' is only used as the reported file name.  Kept as is:
    # _fromast() unwraps the ast.Module / ast.Expr nodes this produces.
    return cls._fromast(ast.parse(explicit, 'eval'))
241 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
245 elif coefficient
== -1:
246 string
+= '-' if i
== 0 else ' - '
248 string
+= '{}*'.format(coefficient
)
249 elif coefficient
> 0:
250 string
+= ' + {}*'.format(coefficient
)
252 string
+= ' - {}*'.format(-coefficient
)
253 string
+= '{}'.format(symbol
)
254 constant
= self
.constant
256 string
+= '{}'.format(constant
)
258 string
+= ' + {}'.format(constant
)
260 string
+= ' - {}'.format(-constant
)
263 def _repr_latex_(self
):
265 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
269 elif coefficient
== -1:
270 string
+= '-' if i
== 0 else ' - '
272 string
+= '{}'.format(coefficient
._repr
_latex
_().strip('$'))
273 elif coefficient
> 0:
274 string
+= ' + {}'.format(coefficient
._repr
_latex
_().strip('$'))
275 elif coefficient
< 0:
276 string
+= ' - {}'.format((-coefficient
)._repr
_latex
_().strip('$'))
277 string
+= '{}'.format(symbol
._repr
_latex
_().strip('$'))
278 constant
= self
.constant
280 string
+= '{}'.format(constant
._repr
_latex
_().strip('$'))
282 string
+= ' + {}'.format(constant
._repr
_latex
_().strip('$'))
284 string
+= ' - {}'.format((-constant
)._repr
_latex
_().strip('$'))
285 return '${}$'.format(string
)
287 def _parenstr(self
, always
=False):
289 if not always
and (self
.isconstant() or self
.issymbol()):
292 return '({})'.format(string
)
295 def fromsympy(cls
, expr
):
299 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
300 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
301 if symbol
== sympy
.S
.One
:
302 constant
= coefficient
303 elif isinstance(symbol
, sympy
.Symbol
):
304 symbol
= Symbol(symbol
.name
)
305 coefficients
.append((symbol
, coefficient
))
307 raise ValueError('non-linear expression: {!r}'.format(expr
))
308 return Expression(coefficients
, constant
)
313 for symbol
, coefficient
in self
.coefficients():
314 term
= coefficient
* sympy
.Symbol(symbol
.name
)
316 expr
+= self
.constant
320 class Symbol(Expression
):
322 def __new__(cls
, name
):
323 if not isinstance(name
, str):
324 raise TypeError('name must be a string')
325 self
= object().__new
__(cls
)
326 self
._name
= name
.strip()
327 self
._coefficients
= {self
: 1}
329 self
._symbols
= (self
,)
338 return hash(self
.sortkey())
def __eq__(self, other):
    """
    Return True when *other* is a non-Dummy Symbol with the same name.
    """
    if isinstance(other, Dummy) or not isinstance(other, Symbol):
        return False
    return self.name == other.name
351 return Dummy(self
.name
)
def _fromast(cls, node):
    """
    Convert an ast node into a Symbol.

    A bare name becomes Symbol(name); an ast.Expr wrapper or a
    single-statement ast.Module is unwrapped recursively.  Anything else
    raises SyntaxError.
    """
    if isinstance(node, ast.Name):
        return Symbol(node.id)
    if isinstance(node, ast.Expr):
        return cls._fromast(node.value)
    if isinstance(node, ast.Module) and len(node.body) == 1:
        return cls._fromast(node.body[0])
    raise SyntaxError('invalid syntax')
def _repr_latex_(self):
    """
    Return the symbol name enclosed in '$' delimiters.
    """
    label = self.name
    return '${}$'.format(label)
370 def fromsympy(cls
, expr
):
372 if isinstance(expr
, sympy
.Symbol
):
373 return cls(expr
.name
)
375 raise TypeError('expr must be a sympy.Symbol instance')
382 def __new__(cls
, name
=None):
384 name
= 'Dummy_{}'.format(Dummy
._count
)
385 self
= object().__new
__(cls
)
386 self
._index
= Dummy
._count
387 self
._name
= name
.strip()
388 self
._coefficients
= {self
: 1}
390 self
._symbols
= (self
,)
396 return hash(self
.sortkey())
399 return self
._name
, self
._index
def __eq__(self, other):
    """
    Return True when *other* is a Dummy carrying the same index.
    """
    if not isinstance(other, Dummy):
        return False
    return self._index == other._index
405 return '_{}'.format(self
.name
)
def _repr_latex_(self):
    """
    Return the name with the index as a subscript, enclosed in '$'.
    """
    label, index = self.name, self._index
    return '${}_{{{}}}$'.format(label, index)
412 if isinstance(names
, str):
413 names
= names
.replace(',', ' ').split()
414 return tuple(Symbol(name
) for name
in names
)
417 class Rational(Expression
, Fraction
):
419 def __new__(cls
, numerator
=0, denominator
=None):
420 self
= Fraction
.__new
__(cls
, numerator
, denominator
)
421 self
._coefficients
= {}
422 self
._constant
= Fraction(self
)
428 return Fraction
.__hash
__(self
)
434 def isconstant(self
):
438 return Fraction
.__bool
__(self
)
def __mul__(self, other):
    """
    Return the product of this rational and the expression *other*.

    Every coefficient of *other* and its constant term are multiplied by
    the value of this rational; the result is built by calling
    Expression().
    """
    scaled = {
        variable: weight * self._constant
        for variable, weight in other._coefficients.items()
    }
    return Expression(scaled, other._constant * self._constant)
def __rtruediv__(self, other):
    """
    Return the quotient of the expression *other* by this rational.

    Every coefficient of *other* and its constant term are divided by the
    value of this rational; the result is built by calling Expression().
    """
    divided = {
        variable: weight / self._constant
        for variable, weight in other._coefficients.items()
    }
    return Expression(divided, other._constant / self._constant)
def fromstring(cls, string):
    """
    Build a Rational from a string.

    Raises TypeError when *string* is not a str instance.
    """
    if isinstance(string, str):
        return Rational(string)
    raise TypeError('string must be a string instance')
465 if self
.denominator
== 1:
466 return '{!r}'.format(self
.numerator
)
468 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)
470 def _repr_latex_(self
):
471 if self
.denominator
== 1:
472 return '${}$'.format(self
.numerator
)
473 elif self
.numerator
< 0:
474 return '$-\\frac{{{}}}{{{}}}$'.format(-self
.numerator
,
477 return '$\\frac{{{}}}{{{}}}$'.format(self
.numerator
,
481 def fromsympy(cls
, expr
):
483 if isinstance(expr
, sympy
.Rational
):
484 return Rational(expr
.p
, expr
.q
)
485 elif isinstance(expr
, numbers
.Rational
):
486 return Rational(expr
)
488 raise TypeError('expr must be a sympy.Rational instance')