# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy. If not, see <http://www.gnu.org/licenses/>.
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from fractions import Fraction

try:
    from collections.abc import Mapping
except ImportError:  # Python < 3.3
    from collections import Mapping

try:
    from math import gcd
except ImportError:  # Python < 3.5
    from fractions import gcd

# Public names exported by this module.
__all__ = [
    'Expression',
    'Symbol', 'Dummy', 'symbols',
    'Rational',
]


34 def _polymorphic(func
):
35 @functools.wraps(func
)
36 def wrapper(left
, right
):
37 if isinstance(right
, Expression
):
38 return func(left
, right
)
39 elif isinstance(right
, numbers
.Rational
):
40 right
= Rational(right
)
41 return func(left
, right
)
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational multiples of symbols plus a rational
    constant. Construction normalizes the result:
    - an expression without non-zero coefficients is returned as a Rational;
    - a single symbol with coefficient 1 and a zero constant is returned as
      that Symbol.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a new expression.

        coefficients may be:
        - a string, which is parsed with fromstring();
        - a mapping from Symbol to rational coefficient;
        - an iterable of (symbol, coefficient) pairs;
        - None, for a constant expression.
        constant is the rational constant term.

        Raises TypeError when a string is combined with a non-zero constant,
        or when symbols, coefficients or the constant have the wrong type.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Drop zero coefficients *before* normalizing. Otherwise x - x would
        # be a non-constant Expression and (x + y) - y would not be the
        # Symbol x: they would compare equal to 0 / x but hash differently.
        coefficients = [(symbol, Fraction(coefficient))
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol.

        A symbol that does not occur in the expression has coefficient 0.
        Raises TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return Rational(self._coefficients.get(symbol, 0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the (symbol, coefficient) pairs of an expression.

        Pairs are yielded in the order stored at construction, i.e. sorted by
        symbol sort key. Coefficients are returned as Rational instances.
        """
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """
        Return the constant value of an expression.
        """
        return Rational(self._constant)

    @property
    def symbols(self):
        """
        Return a tuple of the symbols occurring in an expression.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        Return the number of symbols occurring in an expression.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return true if an expression is a constant.
        """
        return False

    def issymbol(self):
        """
        Return true if an expression is a symbol.
        """
        return False

    def values(self):
        """
        Return the coefficient and constant values of an expression.

        Yield every coefficient (in symbol order), then the constant, as
        Rational instances.
        """
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        """
        Return other - self (other has already been converted to an
        expression by the decorator).
        """
        return other - self

    def __mul__(self, other):
        """
        Return the product of two expressions if other is a rational number.

        Returns NotImplemented otherwise, since the product of two
        non-constant expressions is not linear.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return Expression(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of an expression by a rational number.

        Returns NotImplemented when other is not a rational number.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return Expression(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two expressions are equal
        """
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    def __le__(self, other):
        """
        Return Le(self, other) from the polyhedra module.
        """
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        """
        Return Lt(self, other) from the polyhedra module.
        """
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        """
        Return Ge(self, other) from the polyhedra module.
        """
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        """
        Return Gt(self, other) from the polyhedra module.
        """
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Multiply an expression by a scalar to make all coefficients integer values.

        The scalar is the least common multiple of the denominators of all
        coefficients and of the constant.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute symbol by expression in equations and return the resulting
        expression.

        Either call subs(symbol, expression), or pass a single mapping or
        iterable of (symbol, expression) pairs as symbol. In the second form,
        the substitutions are applied one after the other.
        Raises TypeError if a substituted key is not a Symbol.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result._coefficients.items()
                if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Supported syntax: names (symbols), numeric literals, unary minus, and
        the binary operators +, -, * and /. Anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Constant) and \
                isinstance(node.value, (int, float, complex)) and \
                not isinstance(node.value, bool):
            # ast.Num is deprecated (numbers are parsed as ast.Constant since
            # Python 3.8). Accept the same numeric literals, but not booleans.
            return Rational(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or closing parenthesis directly followed by an identifier or
    # opening parenthesis, e.g. '5x' or '2(x + 1)'. The \b keeps digits that
    # are part of an identifier (as in 'x2y') from being split off.
    _RE_NUM_VAR = re.compile(r'(\b\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string.

        Implicit multiplications such as '5x' or '2(x + 1)' are accepted.
        Unsupported syntax raises SyntaxError.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # Parse a single expression. Note that the second positional argument
        # of ast.parse() is the filename, not the mode.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree.body)

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        """
        Return the expression as a LaTeX display-math string ($$...$$).
        """
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}'.format(coefficient._repr_latex_().strip('$'))
            elif coefficient > 0:
                string += ' + {}'.format(coefficient._repr_latex_().strip('$'))
            elif coefficient < 0:
                string += ' - {}'.format((-coefficient)._repr_latex_().strip('$'))
            string += '{}'.format(symbol._repr_latex_().strip('$'))
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant._repr_latex_().strip('$'))
        elif constant > 0:
            string += ' + {}'.format(constant._repr_latex_().strip('$'))
        elif constant < 0:
            string += ' - {}'.format((-constant)._repr_latex_().strip('$'))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return str(self), parenthesized unless it is a constant or a symbol
        (or always parenthesized when always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert sympy object to an expression.

        Raises ValueError if expr contains a non-linear term.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """
        Return an expression as a sympy object.
        """
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr


class Symbol(Expression):
    """
    This class implements symbols, the named variables of linear expressions.

    A symbol is an expression whose only coefficient is 1 (on itself) and
    whose constant is 0. Symbols compare equal when their sort keys (their
    names) are equal.
    """

    def __new__(cls, name):
        """
        Create and return a symbol from a string.

        Surrounding whitespace is stripped from name.
        Raises TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        Return the name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return the key used to order and compare symbols.
        """
        return self.name,

    def issymbol(self):
        """
        Return true, since a symbol is a symbol.
        """
        return True

    def __eq__(self, other):
        # Only symbols have a sort key. For any other operand, defer to the
        # reflected comparison (and ultimately identity) instead of raising
        # AttributeError, e.g. for `x == 1`.
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a symbol as a Dummy Symbol.
        """
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Build a symbol from a Python AST node; raise SyntaxError otherwise.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Symbol (or sympy.Dummy) into a Symbol (or Dummy).

        Raises TypeError for any other sympy object.
        """
        import sympy
        if isinstance(expr, sympy.Dummy):
            return Dummy(expr.name)
        elif isinstance(expr, sympy.Symbol):
            return Symbol(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')


class Dummy(Symbol):
    """
    This class returns a dummy symbol to ensure that no variables are repeated
    in an expression.

    Each dummy gets a fresh creation index, which is part of its sort key.
    Distinct dummies therefore never compare equal, even when they share a
    name.
    """

    # Index given to the next dummy created.
    _count = 0

    def __new__(cls, name=None):
        """
        Create and return a new dummy symbol.

        When name is None, a name of the form 'Dummy_<index>' is generated.
        Raises TypeError if name is neither None nor a string.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._index = Dummy._count
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        # Advance the counter so that the next dummy gets a different index.
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return (name, index), making every dummy distinct.
        """
        return self._name, self._index

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)


def symbols(names):
    """
    Transform strings into instances of the Symbol class

    names is either a single string of names separated by commas and/or
    whitespace, or an iterable of name strings. Return a tuple of Symbol
    instances, in the given order.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(Symbol(name) for name in names)


class Rational(Expression, Fraction):
    """
    This class represents integers and rational numbers of any size.

    A Rational is at the same time a constant Expression (no symbols,
    dimension 0) and a Fraction.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a rational number from the same arguments as Fraction.
        """
        self = object.__new__(cls)
        self._coefficients = {}
        self._constant = Fraction(numerator, denominator)
        self._symbols = ()
        self._dimension = 0
        # Fraction.__new__ is bypassed above, so populate its fields directly.
        self._numerator = self._constant.numerator
        self._denominator = self._constant.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """
        Return rational as a constant.
        """
        return self

    def isconstant(self):
        """
        Test whether a value is a constant.
        """
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        if self.denominator == 1:
            return '{!r}'.format(self.numerator)
        else:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)

    def _repr_latex_(self):
        """
        Return the number as LaTeX display math, using \\frac for
        non-integers and a leading minus sign for negative fractions.
        """
        if self.denominator == 1:
            return '$${}$$'.format(self.numerator)
        elif self.numerator < 0:
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-self.numerator,
                self.denominator)
        else:
            return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
                self.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a rational object from a sympy expression

        Plain numbers.Rational values are accepted too.
        Raises TypeError for anything else.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return Rational(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')