1 # Copyright 2014 MINES ParisTech
3 # This file is part of LinPy.
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
23 from collections
import OrderedDict
, defaultdict
, Mapping
24 from fractions
import Fraction
, gcd
29 'Symbol', 'Dummy', 'symbols',
34 def _polymorphic(func
):
35 @functools.wraps(func
)
36 def wrapper(left
, right
):
37 if isinstance(right
, Expression
):
38 return func(left
, right
)
39 elif isinstance(right
, numbers
.Rational
):
40 right
= Rational(right
)
41 return func(left
, right
)
48 This class implements linear expressions.
51 def __new__(cls
, coefficients
=None, constant
=0):
52 if isinstance(coefficients
, str):
54 raise TypeError('too many arguments')
55 return Expression
.fromstring(coefficients
)
56 if coefficients
is None:
57 return Rational(constant
)
58 if isinstance(coefficients
, Mapping
):
59 coefficients
= coefficients
.items()
60 coefficients
= list(coefficients
)
61 for symbol
, coefficient
in coefficients
:
62 if not isinstance(symbol
, Symbol
):
63 raise TypeError('symbols must be Symbol instances')
64 if not isinstance(coefficient
, numbers
.Rational
):
65 raise TypeError('coefficients must be rational numbers')
66 if not isinstance(constant
, numbers
.Rational
):
67 raise TypeError('constant must be a rational number')
68 if len(coefficients
) == 0:
69 return Rational(constant
)
70 if len(coefficients
) == 1 and constant
== 0:
71 symbol
, coefficient
= coefficients
[0]
74 coefficients
= [(symbol
, Fraction(coefficient
))
75 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
76 coefficients
.sort(key
=lambda item
: item
[0].sortkey())
77 self
= object().__new
__(cls
)
78 self
._coefficients
= OrderedDict(coefficients
)
79 self
._constant
= Fraction(constant
)
80 self
._symbols
= tuple(self
._coefficients
)
81 self
._dimension
= len(self
._symbols
)
def coefficient(self, symbol):
    """Return the coefficient of ``symbol`` in this expression.

    The result is always a ``Rational``; a symbol that does not occur in the
    expression has coefficient ``Rational(0)``.  The same lookup is exposed
    through subscripting (``expr[symbol]``) via the ``__getitem__`` alias.

    Raises:
        TypeError: if ``symbol`` is not a ``Symbol`` instance.
    """
    if isinstance(symbol, Symbol):
        stored = self._coefficients.get(symbol, 0)
        return Rational(stored)
    raise TypeError('symbol must be a Symbol instance')

__getitem__ = coefficient
def coefficients(self):
    """Yield ``(symbol, coefficient)`` pairs for the non-constant terms.

    Pairs come out in the order they are stored in ``self._coefficients``;
    every coefficient is wrapped in a ``Rational``.
    """
    stored = self._coefficients
    for sym in stored:
        yield sym, Rational(stored[sym])
97 return Rational(self
._constant
)
105 return self
._dimension
108 return hash((tuple(self
._coefficients
.items()), self
._constant
))
110 def isconstant(self
):
117 for coefficient
in self
._coefficients
.values():
118 yield Rational(coefficient
)
119 yield Rational(self
._constant
)
def __add__(self, other):
    """Return ``self + other`` as a new expression.

    ``other`` must expose ``_coefficients`` and ``_constant``; presumably it
    has already been coerced to an Expression by a decorator outside this
    view.  The coefficient maps are merged term by term and the constants
    are added; the Expression constructor then normalizes the result.
    """
    # defaultdict(Fraction) makes symbols that occur only in `other` start
    # from Fraction(0).
    merged = defaultdict(Fraction, self._coefficients)
    for sym, coef in other._coefficients.items():
        merged[sym] = merged[sym] + coef
    return Expression(merged, self._constant + other._constant)
def __sub__(self, other):
    """Return ``self - other`` as a new expression.

    ``other`` must expose ``_coefficients`` and ``_constant``; presumably it
    has already been coerced to an Expression by a decorator outside this
    view.  Coefficients of ``other`` are subtracted term by term and the
    constants are subtracted.
    """
    # Symbols present only in `other` start from Fraction(0).
    difference = defaultdict(Fraction, self._coefficients)
    for sym, coef in other._coefficients.items():
        difference[sym] = difference[sym] - coef
    return Expression(difference, self._constant - other._constant)
149 def __rsub__(self
, other
):
def __mul__(self, other):
    """Scale the expression by a rational number.

    Only ``numbers.Rational`` operands are supported: multiplying two
    expressions would not be linear.  For any other operand
    ``NotImplemented`` is returned so that Python can try the reflected
    operation or raise ``TypeError``.
    """
    if not isinstance(other, numbers.Rational):
        return NotImplemented
    constant = self._constant * other
    scaled = ((sym, coef * other) for sym, coef in self._coefficients.items())
    return Expression(scaled, constant)
def __truediv__(self, other):
    """Divide the expression by a rational number.

    Only ``numbers.Rational`` divisors are supported; for anything else
    ``NotImplemented`` is returned.  Dividing by zero propagates the
    ``ZeroDivisionError`` raised by ``Fraction`` arithmetic.
    """
    if not isinstance(other, numbers.Rational):
        return NotImplemented
    constant = self._constant / other
    quotients = ((sym, coef / other) for sym, coef in self._coefficients.items())
    return Expression(quotients, constant)
def __eq__(self, other):
    """Structural equality: same coefficients and same constant.

    This deliberately returns a plain ``bool`` rather than a constraint
    object; see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
    for the same design choice in SymPy.  Non-Expression operands are never
    equal.
    """
    if not isinstance(other, Expression):
        return False
    if self._coefficients != other._coefficients:
        return False
    return self._constant == other._constant
def __le__(self, other):
    """Build the constraint ``self <= other`` via ``polyhedra.Le``."""
    # Imported at call time rather than at module level -- presumably to
    # avoid a circular import between this module and polyhedra.
    from . import polyhedra
    return polyhedra.Le(self, other)

def __lt__(self, other):
    """Build the constraint ``self < other`` via ``polyhedra.Lt``."""
    from . import polyhedra
    return polyhedra.Lt(self, other)

def __ge__(self, other):
    """Build the constraint ``self >= other`` via ``polyhedra.Ge``."""
    from . import polyhedra
    return polyhedra.Ge(self, other)

def __gt__(self, other):
    """Build the constraint ``self > other`` via ``polyhedra.Gt``."""
    from . import polyhedra
    return polyhedra.Gt(self, other)
195 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
196 [value
.denominator
for value
in self
.values()])
199 def subs(self
, symbol
, expression
=None):
200 if expression
is None:
201 if isinstance(symbol
, Mapping
):
202 symbol
= symbol
.items()
203 substitutions
= symbol
205 substitutions
= [(symbol
, expression
)]
207 for symbol
, expression
in substitutions
:
208 if not isinstance(symbol
, Symbol
):
209 raise TypeError('symbols must be Symbol instances')
210 coefficients
= [(othersymbol
, coefficient
)
211 for othersymbol
, coefficient
in result
._coefficients
.items()
212 if othersymbol
!= symbol
]
213 coefficient
= result
._coefficients
.get(symbol
, 0)
214 constant
= result
._constant
215 result
= Expression(coefficients
, constant
) + coefficient
*expression
219 def _fromast(cls
, node
):
220 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
221 return cls
._fromast
(node
.body
[0])
222 elif isinstance(node
, ast
.Expr
):
223 return cls
._fromast
(node
.value
)
224 elif isinstance(node
, ast
.Name
):
225 return Symbol(node
.id)
226 elif isinstance(node
, ast
.Num
):
227 return Rational(node
.n
)
228 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
229 return -cls
._fromast
(node
.operand
)
230 elif isinstance(node
, ast
.BinOp
):
231 left
= cls
._fromast
(node
.left
)
232 right
= cls
._fromast
(node
.right
)
233 if isinstance(node
.op
, ast
.Add
):
235 elif isinstance(node
.op
, ast
.Sub
):
237 elif isinstance(node
.op
, ast
.Mult
):
239 elif isinstance(node
.op
, ast
.Div
):
241 raise SyntaxError('invalid syntax')
# Locates implicit multiplications such as '5x', '2 (x+1)' or ')(' so that
# fromstring() can insert an explicit '*' between the two captured groups.
# NOTE(review): a digit run inside an identifier also matches, e.g. 'x1y'
# is rewritten to 'x1*y'; symbol names of that shape cannot be parsed.
_RE_NUM_VAR = re.compile(
    r'(\d+|\))'           # group 1: an integer literal or a closing paren
    r'\s*'                # optional whitespace between the two factors
    r'([^\W\d_]\w*|\()'   # group 2: an identifier or an opening paren
)
def fromstring(cls, string):
    """Build an expression of type ``cls`` from its textual form.

    Implicit multiplications such as ``'5x'`` or ``'2(x + 1)'`` are made
    explicit first; the string is then parsed with :mod:`ast` and the tree
    is converted by ``cls._fromast``.

    Raises:
        SyntaxError: if the text is not valid Python syntax, or if
            ``_fromast`` rejects the parsed form ('invalid syntax').
    """
    # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
    string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
    # The second positional parameter of ast.parse() is *filename*, not
    # *mode*; passing 'eval' positionally only labelled SyntaxErrors as
    # coming from a file named "eval".  'exec' mode is kept on purpose:
    # _fromast() unwraps an ast.Module holding one expression statement.
    tree = ast.parse(string, filename='<string>', mode='exec')
    return cls._fromast(tree)
254 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
258 elif coefficient
== -1:
259 string
+= '-' if i
== 0 else ' - '
261 string
+= '{}*'.format(coefficient
)
262 elif coefficient
> 0:
263 string
+= ' + {}*'.format(coefficient
)
265 string
+= ' - {}*'.format(-coefficient
)
266 string
+= '{}'.format(symbol
)
267 constant
= self
.constant
269 string
+= '{}'.format(constant
)
271 string
+= ' + {}'.format(constant
)
273 string
+= ' - {}'.format(-constant
)
276 def _repr_latex_(self
):
278 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
282 elif coefficient
== -1:
283 string
+= '-' if i
== 0 else ' - '
285 string
+= '{}'.format(coefficient
._repr
_latex
_().strip('$'))
286 elif coefficient
> 0:
287 string
+= ' + {}'.format(coefficient
._repr
_latex
_().strip('$'))
288 elif coefficient
< 0:
289 string
+= ' - {}'.format((-coefficient
)._repr
_latex
_().strip('$'))
290 string
+= '{}'.format(symbol
._repr
_latex
_().strip('$'))
291 constant
= self
.constant
293 string
+= '{}'.format(constant
._repr
_latex
_().strip('$'))
295 string
+= ' + {}'.format(constant
._repr
_latex
_().strip('$'))
297 string
+= ' - {}'.format((-constant
)._repr
_latex
_().strip('$'))
298 return '$${}$$'.format(string
)
300 def _parenstr(self
, always
=False):
302 if not always
and (self
.isconstant() or self
.issymbol()):
305 return '({})'.format(string
)
308 def fromsympy(cls
, expr
):
312 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
313 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
314 if symbol
== sympy
.S
.One
:
315 constant
= coefficient
316 elif isinstance(symbol
, sympy
.Symbol
):
317 symbol
= Symbol(symbol
.name
)
318 coefficients
.append((symbol
, coefficient
))
320 raise ValueError('non-linear expression: {!r}'.format(expr
))
321 return Expression(coefficients
, constant
)
326 for symbol
, coefficient
in self
.coefficients():
327 term
= coefficient
* sympy
.Symbol(symbol
.name
)
329 expr
+= self
.constant
333 class Symbol(Expression
):
335 def __new__(cls
, name
):
336 if not isinstance(name
, str):
337 raise TypeError('name must be a string')
338 self
= object().__new
__(cls
)
339 self
._name
= name
.strip()
340 self
._coefficients
= {self
: Fraction(1)}
341 self
._constant
= Fraction(0)
342 self
._symbols
= (self
,)
351 return hash(self
.sortkey())
def __eq__(self, other):
    """Return True if ``other`` denotes the same symbol.

    Two symbols are equal when their ``sortkey()`` values match.  Any other
    operand is compared with the structural rule of ``Expression.__eq__``
    (same coefficients and constant), so the result is always a plain
    ``bool`` and never a constraint.
    """
    # sortkey() is only known to exist on Symbol (and its subclasses): an
    # int has no such attribute, so `symbol == 0` used to raise
    # AttributeError.  Since Symbol subclasses Expression, Python also calls
    # this method first for `expression == symbol`, so the guard matters
    # for mixed comparisons as well.
    if isinstance(other, Symbol):
        return self.sortkey() == other.sortkey()
    return Expression.__eq__(self, other)
363 return Dummy(self
.name
)
def _fromast(cls, node):
    """Convert a parsed AST into a Symbol.

    Accepts a bare name, possibly wrapped in an expression statement and/or
    a single-statement module (the shape ``ast.parse`` produces for source
    like ``'x'``).  Anything else raises ``SyntaxError('invalid syntax')``.
    """
    # The three node classes are disjoint, so the order of the checks does
    # not matter; wrappers are peeled off recursively through cls._fromast.
    if isinstance(node, ast.Name):
        return Symbol(node.id)
    if isinstance(node, ast.Expr):
        return cls._fromast(node.value)
    if isinstance(node, ast.Module) and len(node.body) == 1:
        return cls._fromast(node.body[0])
    raise SyntaxError('invalid syntax')
def _repr_latex_(self):
    """Return the symbol's name as a display-math LaTeX string."""
    return '$${name}$$'.format(name=self.name)
382 def fromsympy(cls
, expr
):
384 if isinstance(expr
, sympy
.Dummy
):
385 return Dummy(expr
.name
)
386 elif isinstance(expr
, sympy
.Symbol
):
387 return Symbol(expr
.name
)
389 raise TypeError('expr must be a sympy.Symbol instance')
396 def __new__(cls
, name
=None):
398 name
= 'Dummy_{}'.format(Dummy
._count
)
399 elif not isinstance(name
, str):
400 raise TypeError('name must be a string')
401 self
= object().__new
__(cls
)
402 self
._index
= Dummy
._count
403 self
._name
= name
.strip()
404 self
._coefficients
= {self
: Fraction(1)}
405 self
._constant
= Fraction(0)
406 self
._symbols
= (self
,)
412 return hash(self
.sortkey())
415 return self
._name
, self
._index
418 return '_{}'.format(self
.name
)
def _repr_latex_(self):
    """Return the dummy's name subscripted by its index, as display math."""
    # Doubled braces emit literal '{' and '}' around the LaTeX subscript.
    return '$${name}_{{{index}}}$$'.format(name=self.name, index=self._index)
425 if isinstance(names
, str):
426 names
= names
.replace(',', ' ').split()
427 return tuple(Symbol(name
) for name
in names
)
430 class Rational(Expression
, Fraction
):
432 def __new__(cls
, numerator
=0, denominator
=None):
433 self
= object().__new
__(cls
)
434 self
._coefficients
= {}
435 self
._constant
= Fraction(numerator
, denominator
)
438 self
._numerator
= self
._constant
.numerator
439 self
._denominator
= self
._constant
.denominator
443 return Fraction
.__hash
__(self
)
449 def isconstant(self
):
453 return Fraction
.__bool
__(self
)
456 if self
.denominator
== 1:
457 return '{!r}'.format(self
.numerator
)
459 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)
461 def _repr_latex_(self
):
462 if self
.denominator
== 1:
463 return '$${}$$'.format(self
.numerator
)
464 elif self
.numerator
< 0:
465 return '$$-\\frac{{{}}}{{{}}}$$'.format(-self
.numerator
,
468 return '$$\\frac{{{}}}{{{}}}$$'.format(self
.numerator
,
472 def fromsympy(cls
, expr
):
474 if isinstance(expr
, sympy
.Rational
):
475 return Rational(expr
.p
, expr
.q
)
476 elif isinstance(expr
, numbers
.Rational
):
477 return Rational(expr
)
479 raise TypeError('expr must be a sympy.Rational instance')