6 from collections
import OrderedDict
, defaultdict
, Mapping
7 from fractions
import Fraction
, gcd
12 'Symbol', 'Dummy', 'symbols',
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Rational(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
34 def __new__(cls
, coefficients
=None, constant
=0):
35 if isinstance(coefficients
, str):
37 raise TypeError('too many arguments')
38 return Expression
.fromstring(coefficients
)
39 if coefficients
is None:
40 return Rational(constant
)
41 if isinstance(coefficients
, Mapping
):
42 coefficients
= coefficients
.items()
43 for symbol
, coefficient
in coefficients
:
44 if not isinstance(symbol
, Symbol
):
45 raise TypeError('symbols must be Symbol instances')
46 if not isinstance(coefficient
, numbers
.Rational
):
47 raise TypeError('coefficients must be rational numbers')
48 coefficients
= [(symbol
, Fraction(coefficient
))
49 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
50 if not isinstance(constant
, numbers
.Rational
):
51 raise TypeError('constant must be a rational number')
52 constant
= Fraction(constant
)
53 if len(coefficients
) == 0:
54 return Rational(constant
)
55 if len(coefficients
) == 1 and constant
== 0:
56 symbol
, coefficient
= coefficients
[0]
59 self
= object().__new
__(cls
)
60 self
._coefficients
= OrderedDict(sorted(coefficients
,
61 key
=lambda item
: item
[0].sortkey()))
62 self
._constant
= constant
63 self
._symbols
= tuple(self
._coefficients
)
64 self
._dimension
= len(self
._symbols
)
67 def coefficient(self
, symbol
):
68 if not isinstance(symbol
, Symbol
):
69 raise TypeError('symbol must be a Symbol instance')
71 return Rational(self
._coefficients
[symbol
])
75 __getitem__
= coefficient
77 def coefficients(self
):
78 for symbol
, coefficient
in self
._coefficients
.items():
79 yield symbol
, Rational(coefficient
)
83 return Rational(self
._constant
)
91 return self
._dimension
94 return hash((tuple(self
._coefficients
.items()), self
._constant
))
103 for coefficient
in self
._coefficients
.values():
104 yield Rational(coefficient
)
105 yield Rational(self
._constant
)
117 def __add__(self
, other
):
118 coefficients
= defaultdict(Fraction
, self
._coefficients
)
119 for symbol
, coefficient
in other
._coefficients
.items():
120 coefficients
[symbol
] += coefficient
121 constant
= self
._constant
+ other
._constant
122 return Expression(coefficients
, constant
)
127 def __sub__(self
, other
):
128 coefficients
= defaultdict(Fraction
, self
._coefficients
)
129 for symbol
, coefficient
in other
._coefficients
.items():
130 coefficients
[symbol
] -= coefficient
131 constant
= self
._constant
- other
._constant
132 return Expression(coefficients
, constant
)
134 def __rsub__(self
, other
):
135 return -(self
- other
)
137 def __mul__(self
, other
):
138 if isinstance(other
, numbers
.Rational
):
139 coefficients
= dict(self
._coefficients
)
140 for symbol
in coefficients
:
141 coefficients
[symbol
] *= other
142 constant
= self
._constant
* other
143 return Expression(coefficients
, constant
)
144 return NotImplemented
148 def __truediv__(self
, other
):
149 if isinstance(other
, numbers
.Rational
):
150 coefficients
= dict(self
._coefficients
)
151 for symbol
in coefficients
:
152 coefficients
[symbol
] /= other
153 constant
= self
._constant
/ other
154 # import pdb; pdb.set_trace()
155 return Expression(coefficients
, constant
)
156 return NotImplemented
159 def __eq__(self
, other
):
161 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
162 return isinstance(other
, Expression
) and \
163 self
._coefficients
== other
._coefficients
and \
164 self
._constant
== other
._constant
167 def __le__(self
, other
):
168 from .polyhedra
import Le
169 return Le(self
, other
)
172 def __lt__(self
, other
):
173 from .polyhedra
import Lt
174 return Lt(self
, other
)
177 def __ge__(self
, other
):
178 from .polyhedra
import Ge
179 return Ge(self
, other
)
182 def __gt__(self
, other
):
183 from .polyhedra
import Gt
184 return Gt(self
, other
)
187 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
188 [value
.denominator
for value
in self
.values()])
191 def subs(self
, symbol
, expression
=None):
192 if expression
is None:
193 if isinstance(symbol
, Mapping
):
194 symbol
= symbol
.items()
195 substitutions
= symbol
197 substitutions
= [(symbol
, expression
)]
199 for symbol
, expression
in substitutions
:
200 if not isinstance(symbol
, Symbol
):
201 raise TypeError('symbols must be Symbol instances')
202 coefficients
= [(othersymbol
, coefficient
)
203 for othersymbol
, coefficient
in result
._coefficients
.items()
204 if othersymbol
!= symbol
]
205 coefficient
= result
._coefficients
.get(symbol
, 0)
206 constant
= result
._constant
207 result
= Expression(coefficients
, constant
) + coefficient
*expression
211 def _fromast(cls
, node
):
212 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
213 return cls
._fromast
(node
.body
[0])
214 elif isinstance(node
, ast
.Expr
):
215 return cls
._fromast
(node
.value
)
216 elif isinstance(node
, ast
.Name
):
217 return Symbol(node
.id)
218 elif isinstance(node
, ast
.Num
):
219 return Rational(node
.n
)
220 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
221 return -cls
._fromast
(node
.operand
)
222 elif isinstance(node
, ast
.BinOp
):
223 left
= cls
._fromast
(node
.left
)
224 right
= cls
._fromast
(node
.right
)
225 if isinstance(node
.op
, ast
.Add
):
227 elif isinstance(node
.op
, ast
.Sub
):
229 elif isinstance(node
.op
, ast
.Mult
):
231 elif isinstance(node
.op
, ast
.Div
):
233 raise SyntaxError('invalid syntax')
235 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
238 def fromstring(cls
, string
):
239 # add implicit multiplication operators, e.g. '5x' -> '5*x'
240 string
= Expression
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
241 tree
= ast
.parse(string
, 'eval')
242 return cls
._fromast
(tree
)
246 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
250 elif coefficient
== -1:
251 string
+= '-' if i
== 0 else ' - '
253 string
+= '{}*'.format(coefficient
)
254 elif coefficient
> 0:
255 string
+= ' + {}*'.format(coefficient
)
257 string
+= ' - {}*'.format(-coefficient
)
258 string
+= '{}'.format(symbol
)
259 constant
= self
.constant
261 string
+= '{}'.format(constant
)
263 string
+= ' + {}'.format(constant
)
265 string
+= ' - {}'.format(-constant
)
268 def _repr_latex_(self
):
270 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
274 elif coefficient
== -1:
275 string
+= '-' if i
== 0 else ' - '
277 string
+= '{}'.format(coefficient
._repr
_latex
_().strip('$'))
278 elif coefficient
> 0:
279 string
+= ' + {}'.format(coefficient
._repr
_latex
_().strip('$'))
280 elif coefficient
< 0:
281 string
+= ' - {}'.format((-coefficient
)._repr
_latex
_().strip('$'))
282 string
+= '{}'.format(symbol
._repr
_latex
_().strip('$'))
283 constant
= self
.constant
285 string
+= '{}'.format(constant
._repr
_latex
_().strip('$'))
287 string
+= ' + {}'.format(constant
._repr
_latex
_().strip('$'))
289 string
+= ' - {}'.format((-constant
)._repr
_latex
_().strip('$'))
290 return '${}$'.format(string
)
292 def _parenstr(self
, always
=False):
294 if not always
and (self
.isconstant() or self
.issymbol()):
297 return '({})'.format(string
)
300 def fromsympy(cls
, expr
):
304 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
305 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
306 if symbol
== sympy
.S
.One
:
307 constant
= coefficient
308 elif isinstance(symbol
, sympy
.Symbol
):
309 symbol
= Symbol(symbol
.name
)
310 coefficients
.append((symbol
, coefficient
))
312 raise ValueError('non-linear expression: {!r}'.format(expr
))
313 return Expression(coefficients
, constant
)
318 for symbol
, coefficient
in self
.coefficients():
319 term
= coefficient
* sympy
.Symbol(symbol
.name
)
321 expr
+= self
.constant
325 class Symbol(Expression
):
327 def __new__(cls
, name
):
328 if not isinstance(name
, str):
329 raise TypeError('name must be a string')
330 self
= object().__new
__(cls
)
331 self
._name
= name
.strip()
332 self
._coefficients
= {self
: Fraction(1)}
333 self
._constant
= Fraction(0)
334 self
._symbols
= (self
,)
343 return hash(self
.sortkey())
351 def __eq__(self
, other
):
352 return not isinstance(other
, Dummy
) and isinstance(other
, Symbol
) \
353 and self
.name
== other
.name
356 return Dummy(self
.name
)
359 def _fromast(cls
, node
):
360 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
361 return cls
._fromast
(node
.body
[0])
362 elif isinstance(node
, ast
.Expr
):
363 return cls
._fromast
(node
.value
)
364 elif isinstance(node
, ast
.Name
):
365 return Symbol(node
.id)
366 raise SyntaxError('invalid syntax')
371 def _repr_latex_(self
):
372 return '${}$'.format(self
.name
)
375 def fromsympy(cls
, expr
):
377 if isinstance(expr
, sympy
.Symbol
):
378 return cls(expr
.name
)
380 raise TypeError('expr must be a sympy.Symbol instance')
387 def __new__(cls
, name
=None):
389 name
= 'Dummy_{}'.format(Dummy
._count
)
390 self
= object().__new
__(cls
)
391 self
._index
= Dummy
._count
392 self
._name
= name
.strip()
393 self
._coefficients
= {self
: Fraction(1)}
394 self
._constant
= Fraction(0)
395 self
._symbols
= (self
,)
401 return hash(self
.sortkey())
404 return self
._name
, self
._index
406 def __eq__(self
, other
):
407 return isinstance(other
, Dummy
) and self
._index
== other
._index
410 return '_{}'.format(self
.name
)
412 def _repr_latex_(self
):
413 return '${}_{{{}}}$'.format(self
.name
, self
._index
)
417 if isinstance(names
, str):
418 names
= names
.replace(',', ' ').split()
419 return tuple(Symbol(name
) for name
in names
)
422 class Rational(Expression
, Fraction
):
424 def __new__(cls
, numerator
=0, denominator
=None):
425 self
= Fraction
.__new
__(cls
, numerator
, denominator
)
426 self
._coefficients
= {}
427 self
._constant
= Fraction(self
)
433 return Fraction
.__hash
__(self
)
439 def isconstant(self
):
443 return Fraction
.__bool
__(self
)
446 def fromstring(cls
, string
):
447 if not isinstance(string
, str):
448 raise TypeError('string must be a string instance')
449 return Rational(string
)
452 if self
.denominator
== 1:
453 return '{!r}'.format(self
.numerator
)
455 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)
457 def _repr_latex_(self
):
458 if self
.denominator
== 1:
459 return '${}$'.format(self
.numerator
)
460 elif self
.numerator
< 0:
461 return '$-\\frac{{{}}}{{{}}}$'.format(-self
.numerator
,
464 return '$\\frac{{{}}}{{{}}}$'.format(self
.numerator
,
468 def fromsympy(cls
, expr
):
470 if isinstance(expr
, sympy
.Rational
):
471 return Rational(expr
.p
, expr
.q
)
472 elif isinstance(expr
, numbers
.Rational
):
473 return Rational(expr
)
475 raise TypeError('expr must be a sympy.Rational instance')