ab5d344c8006d9394500ec74d741d2692db5e34d
1 # Copyright 2014 MINES ParisTech
3 # This file is part of LinPy.
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
23 from collections
import OrderedDict
, defaultdict
, Mapping
24 from fractions
import Fraction
, gcd
29 'Symbol', 'Dummy', 'symbols',
def _polymorphic(func):
    """
    Decorate a binary operator so that it accepts rational right operands.

    The decorated function is called unchanged when the right operand is a
    LinExpr instance. When it is a numbers.Rational instance, it is first
    converted with Rational(). For any other operand type, NotImplemented is
    returned so that Python can try the reflected operation or raise
    TypeError.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if isinstance(right, LinExpr):
            return func(left, right)
        if isinstance(right, numbers.Rational):
            return func(left, Rational(right))
        # Unsupported operand: defer to the binary operator protocol instead
        # of silently returning None.
        return NotImplemented
    return wrapper
48 A linear expression consists of a list of coefficient-variable pairs
49 that capture the linear terms, plus a constant term. Linear expressions
50 are used to build constraints. They are temporary objects that typically
53 Linear expressions are generally built using overloaded operators. For
54 example, if x is a Symbol, then x + 1 is an instance of LinExpr.
56 LinExpr instances are hashable, and should be treated as immutable.
59 def __new__(cls
, coefficients
=None, constant
=0):
61 Return a linear expression from a dictionary or a sequence, that maps
62 symbols to their coefficients, and a constant term. The coefficients and
63 the constant term must be rational numbers.
65 For example, the linear expression x + 2y + 1 can be constructed using
66 one of the following instructions:
68 >>> x, y = symbols('x y')
69 >>> LinExpr({x: 1, y: 2}, 1)
70 >>> LinExpr([(x, 1), (y, 2)], 1)
72 However, it may be easier to use overloaded operators:
74 >>> x, y = symbols('x y')
77 Alternatively, linear expressions can be constructed from a string:
79 >>> LinExpr('x + 2*y + 1')
81 A linear expression with a single symbol of coefficient 1 and no
82 constant term is automatically subclassed as a Symbol instance. A linear
83 expression with no symbol, only a constant term, is automatically
84 subclassed as a Rational instance.
86 if isinstance(coefficients
, str):
88 raise TypeError('too many arguments')
89 return LinExpr
.fromstring(coefficients
)
90 if coefficients
is None:
91 return Rational(constant
)
92 if isinstance(coefficients
, Mapping
):
93 coefficients
= coefficients
.items()
94 coefficients
= list(coefficients
)
95 for symbol
, coefficient
in coefficients
:
96 if not isinstance(symbol
, Symbol
):
97 raise TypeError('symbols must be Symbol instances')
98 if not isinstance(coefficient
, numbers
.Rational
):
99 raise TypeError('coefficients must be rational numbers')
100 if not isinstance(constant
, numbers
.Rational
):
101 raise TypeError('constant must be a rational number')
102 if len(coefficients
) == 0:
103 return Rational(constant
)
104 if len(coefficients
) == 1 and constant
== 0:
105 symbol
, coefficient
= coefficients
[0]
108 coefficients
= [(symbol
, Fraction(coefficient
))
109 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
110 coefficients
.sort(key
=lambda item
: item
[0].sortkey())
111 self
= object().__new
__(cls
)
112 self
._coefficients
= OrderedDict(coefficients
)
113 self
._constant
= Fraction(constant
)
114 self
._symbols
= tuple(self
._coefficients
)
115 self
._dimension
= len(self
._symbols
)
def coefficient(self, symbol):
    """
    Look up the coefficient attached to the given symbol.

    Return Fraction(0) if the symbol does not appear in the expression.
    Raise TypeError if the argument is not a Symbol instance.
    """
    if isinstance(symbol, Symbol):
        return self._coefficients.get(symbol, Fraction(0))
    raise TypeError('symbol must be a Symbol instance')

# expr[symbol] is a shorthand for expr.coefficient(symbol)
__getitem__ = coefficient
def coefficients(self):
    """
    Generate the (symbol, value) pairs of the linear terms of the
    expression. The constant term is not part of the iteration.
    """
    for term in self._coefficients.items():
        yield term
139 The constant term of the expression.
141 return self
._constant
146 The tuple of symbols present in the expression, sorted according to
154 The dimension of the expression, i.e. the number of symbols present in
157 return self
._dimension
160 return hash((tuple(self
._coefficients
.items()), self
._constant
))
162 def isconstant(self
):
164 Return True if the expression only consists of a constant term. In this
165 case, it is a Rational instance.
171 Return True if an expression only consists of a symbol with coefficient
172 1. In this case, it is a Symbol instance.
178 Iterate over the coefficient values in the expression, and the constant
181 for coefficient
in self
._coefficients
.values():
def __add__(self, other):
    """
    Return the sum of two linear expressions.

    The coefficients of both operands are accumulated symbol by symbol, the
    constant terms are added, and a new LinExpr is built from the result.
    """
    terms = defaultdict(Fraction, self._coefficients)
    for symbol, value in other._coefficients.items():
        # symbols absent from self start from Fraction() == 0
        terms[symbol] += value
    return LinExpr(terms, self._constant + other._constant)
def __sub__(self, other):
    """
    Return the difference between two linear expressions.

    The coefficients of other are subtracted from those of self symbol by
    symbol, the constant terms are subtracted, and a new LinExpr is built
    from the result.
    """
    terms = defaultdict(Fraction, self._coefficients)
    for symbol, value in other._coefficients.items():
        # symbols absent from self start from Fraction() == 0
        terms[symbol] -= value
    return LinExpr(terms, self._constant - other._constant)
219 def __rsub__(self
, other
):
def __mul__(self, other):
    """
    Return the product of the linear expression by a rational.

    NotImplemented is returned when other is not a numbers.Rational
    instance.
    """
    if not isinstance(other, numbers.Rational):
        return NotImplemented
    scaled = [(symbol, value * other)
        for symbol, value in self._coefficients.items()]
    return LinExpr(scaled, self._constant * other)
def __truediv__(self, other):
    """
    Return the quotient of the linear expression by a rational.

    NotImplemented is returned when other is not a numbers.Rational
    instance.
    """
    if not isinstance(other, numbers.Rational):
        return NotImplemented
    # compute the constant first, as the original does before its lazy
    # coefficient generator is consumed
    constant = self._constant / other
    scaled = [(symbol, value / other)
        for symbol, value in self._coefficients.items()]
    return LinExpr(scaled, constant)
def __eq__(self, other):
    """
    Test whether two linear expressions are equal.

    Two expressions are equal when other is a LinExpr instance and both the
    coefficient mappings and the constant terms compare equal.
    """
    if not isinstance(other, LinExpr):
        return False
    if self._coefficients != other._coefficients:
        return False
    return self._constant == other._constant
def __le__(self, other):
    """
    Return Le(self, other), with Le taken from the sibling polyhedra module.

    The import happens at call time rather than at module level (presumably
    to avoid a circular import between the two modules; confirm).
    """
    from .polyhedra import Le as build
    return build(self, other)
def __lt__(self, other):
    """
    Return Lt(self, other), with Lt taken from the sibling polyhedra module.

    The import happens at call time rather than at module level (presumably
    to avoid a circular import between the two modules; confirm).
    """
    from .polyhedra import Lt as build
    return build(self, other)
def __ge__(self, other):
    """
    Return Ge(self, other), with Ge taken from the sibling polyhedra module.

    The import happens at call time rather than at module level (presumably
    to avoid a circular import between the two modules; confirm).
    """
    from .polyhedra import Ge as build
    return build(self, other)
def __gt__(self, other):
    """
    Return Gt(self, other), with Gt taken from the sibling polyhedra module.

    The import happens at call time rather than at module level (presumably
    to avoid a circular import between the two modules; confirm).
    """
    from .polyhedra import Gt as build
    return build(self, other)
273 Return the expression multiplied by its lowest common denominator to
274 make all values integer.
276 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
277 [value
.denominator
for value
in self
.values()])
280 def subs(self
, symbol
, expression
=None):
282 Substitute the given symbol by an expression and return the resulting
283 expression. Raise TypeError if the resulting expression is not linear.
285 >>> x, y = symbols('x y')
290 To perform multiple substitutions at once, pass a sequence or a
291 dictionary of (old, new) pairs to subs.
293 >>> e.subs({x: y, y: x})
296 if expression
is None:
297 if isinstance(symbol
, Mapping
):
298 symbol
= symbol
.items()
299 substitutions
= symbol
301 substitutions
= [(symbol
, expression
)]
303 for symbol
, expression
in substitutions
:
304 if not isinstance(symbol
, Symbol
):
305 raise TypeError('symbols must be Symbol instances')
306 coefficients
= [(othersymbol
, coefficient
)
307 for othersymbol
, coefficient
in result
._coefficients
.items()
308 if othersymbol
!= symbol
]
309 coefficient
= result
._coefficients
.get(symbol
, 0)
310 constant
= result
._constant
311 result
= LinExpr(coefficients
, constant
) + coefficient
*expression
315 def _fromast(cls
, node
):
316 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
317 return cls
._fromast
(node
.body
[0])
318 elif isinstance(node
, ast
.Expr
):
319 return cls
._fromast
(node
.value
)
320 elif isinstance(node
, ast
.Name
):
321 return Symbol(node
.id)
322 elif isinstance(node
, ast
.Num
):
323 return Rational(node
.n
)
324 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
325 return -cls
._fromast
(node
.operand
)
326 elif isinstance(node
, ast
.BinOp
):
327 left
= cls
._fromast
(node
.left
)
328 right
= cls
._fromast
(node
.right
)
329 if isinstance(node
.op
, ast
.Add
):
331 elif isinstance(node
.op
, ast
.Sub
):
333 elif isinstance(node
.op
, ast
.Mult
):
335 elif isinstance(node
.op
, ast
.Div
):
337 raise SyntaxError('invalid syntax')
# Matches a number or a closing parenthesis followed, possibly after some
# whitespace, by a name starting with a letter or by an opening parenthesis.
# fromstring() relies on it to insert implicit multiplication operators,
# e.g. '5x' -> '5*x'.
_RE_NUM_VAR = re.compile(
    r'(\d+|\))\s*([^\W\d_]\w*|\()'
)
342 def fromstring(cls
, string
):
344 Create an expression from a string. Raise SyntaxError if the string is
345 not properly formatted.
347 # add implicit multiplication operators, e.g. '5x' -> '5*x'
348 string
= LinExpr
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
349 tree
= ast
.parse(string
, 'eval')
350 expr
= cls
._fromast
(tree
)
351 if not isinstance(expr
, cls
):
352 raise SyntaxError('invalid syntax')
357 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
361 elif coefficient
== -1:
362 string
+= '-' if i
== 0 else ' - '
364 string
+= '{}*'.format(coefficient
)
365 elif coefficient
> 0:
366 string
+= ' + {}*'.format(coefficient
)
368 string
+= ' - {}*'.format(-coefficient
)
369 string
+= '{}'.format(symbol
)
370 constant
= self
.constant
372 string
+= '{}'.format(constant
)
374 string
+= ' + {}'.format(constant
)
376 string
+= ' - {}'.format(-constant
)
379 def _repr_latex_(self
):
381 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
385 elif coefficient
== -1:
386 string
+= '-' if i
== 0 else ' - '
388 string
+= '{}'.format(coefficient
._repr
_latex
_().strip('$'))
389 elif coefficient
> 0:
390 string
+= ' + {}'.format(coefficient
._repr
_latex
_().strip('$'))
391 elif coefficient
< 0:
392 string
+= ' - {}'.format((-coefficient
)._repr
_latex
_().strip('$'))
393 string
+= '{}'.format(symbol
._repr
_latex
_().strip('$'))
394 constant
= self
.constant
396 string
+= '{}'.format(constant
._repr
_latex
_().strip('$'))
398 string
+= ' + {}'.format(constant
._repr
_latex
_().strip('$'))
400 string
+= ' - {}'.format((-constant
)._repr
_latex
_().strip('$'))
401 return '$${}$$'.format(string
)
403 def _parenstr(self
, always
=False):
405 if not always
and (self
.isconstant() or self
.issymbol()):
408 return '({})'.format(string
)
411 def fromsympy(cls
, expr
):
413 Create a linear expression from a sympy expression. Raise TypeError if
414 the sympy expression is not linear.
419 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
420 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
421 if symbol
== sympy
.S
.One
:
422 constant
= coefficient
423 elif isinstance(symbol
, sympy
.Dummy
):
424 # we cannot properly convert dummy symbols
425 raise TypeError('cannot convert dummy symbols')
426 elif isinstance(symbol
, sympy
.Symbol
):
427 symbol
= Symbol(symbol
.name
)
428 coefficients
.append((symbol
, coefficient
))
430 raise TypeError('non-linear expression: {!r}'.format(expr
))
431 expr
= LinExpr(coefficients
, constant
)
432 if not isinstance(expr
, cls
):
433 raise TypeError('cannot convert to a {} instance'.format(cls
.__name
__))
438 Convert the linear expression to a sympy expression.
442 for symbol
, coefficient
in self
.coefficients():
443 term
= coefficient
* sympy
.Symbol(symbol
.name
)
445 expr
+= self
.constant
449 class Symbol(LinExpr
):
451 Symbols are the basic components to build expressions and constraints.
452 They correspond to mathematical variables. Symbols are instances of
453 class LinExpr and inherit its functionalities.
455 Two instances of Symbol are equal if they have the same name.
458 def __new__(cls
, name
):
460 Return a symbol with the name string given in argument.
462 if not isinstance(name
, str):
463 raise TypeError('name must be a string')
464 node
= ast
.parse(name
)
466 name
= node
.body
[0].value
.id
467 except (AttributeError, SyntaxError):
468 raise SyntaxError('invalid syntax')
469 self
= object().__new
__(cls
)
471 self
._coefficients
= {self
: Fraction(1)}
472 self
._constant
= Fraction(0)
473 self
._symbols
= (self
,)
480 The name of the symbol.
485 return hash(self
.sortkey())
489 Return a sorting key for the symbol. It is useful to sort a list of
490 symbols in a consistent order, as comparison functions are overridden
491 (see the documentation of class LinExpr).
493 >>> sorted(symbols, key=Symbol.sortkey)
def __eq__(self, other):
    """
    Test whether two symbols are equal, by comparing their sorting keys.

    Return NotImplemented when other is not a Symbol instance, so that
    Python falls back to the reflected comparison instead of failing with
    AttributeError on operands that have no sortkey() method.
    """
    if not isinstance(other, Symbol):
        return NotImplemented
    return self.sortkey() == other.sortkey()
505 Return a new Dummy symbol instance with the same name.
507 return Dummy(self
.name
)
def _repr_latex_(self):
    """
    Return the name of the symbol enclosed in '$$' delimiters.
    """
    template = '$${}$$'
    return template.format(self.name)
518 This function returns a tuple of symbols whose names are taken from a comma
519 or whitespace delimited string, or a sequence of strings. It is useful to
520 define several symbols at once.
522 >>> x, y = symbols('x y')
523 >>> x, y = symbols('x, y')
524 >>> x, y = symbols(['x', 'y'])
526 if isinstance(names
, str):
527 names
= names
.replace(',', ' ').split()
528 return tuple(Symbol(name
) for name
in names
)
533 A variation of Symbol in which all symbols are unique and identified by
534 an internal count index. If a name is not supplied then a string value
535 of the count index will be used. This is useful when a unique, temporary
536 variable is needed and the name of the variable used in the expression
539 Unlike Symbol, Dummy instances with the same name are not equal:
542 >>> x1, x2 = Dummy('x'), Dummy('x')
553 def __new__(cls
, name
=None):
555 Return a fresh dummy symbol with the name string given in argument.
558 name
= 'Dummy_{}'.format(Dummy
._count
)
559 elif not isinstance(name
, str):
560 raise TypeError('name must be a string')
561 self
= object().__new
__(cls
)
562 self
._index
= Dummy
._count
563 self
._name
= name
.strip()
564 self
._coefficients
= {self
: Fraction(1)}
565 self
._constant
= Fraction(0)
566 self
._symbols
= (self
,)
572 return hash(self
.sortkey())
575 return self
._name
, self
._index
578 return '_{}'.format(self
.name
)
def _repr_latex_(self):
    """
    Return the name of the dummy symbol, subscripted with its internal
    index, enclosed in '$$' delimiters.
    """
    template = '$${}_{{{}}}$$'
    return template.format(self.name, self._index)
584 class Rational(LinExpr
, Fraction
):
586 A particular case of linear expressions are rational values, i.e. linear
587 expressions consisting only of a constant term, with no symbol. They are
588 implemented by the Rational class, that inherits from both LinExpr and
589 fractions.Fraction classes.
592 def __new__(cls
, numerator
=0, denominator
=None):
593 self
= object().__new
__(cls
)
594 self
._coefficients
= {}
595 self
._constant
= Fraction(numerator
, denominator
)
598 self
._numerator
= self
._constant
.numerator
599 self
._denominator
= self
._constant
.denominator
603 return Fraction
.__hash
__(self
)
609 def isconstant(self
):
613 return Fraction
.__bool
__(self
)
616 if self
.denominator
== 1:
617 return '{!r}'.format(self
.numerator
)
619 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)
621 def _repr_latex_(self
):
622 if self
.denominator
== 1:
623 return '$${}$$'.format(self
.numerator
)
624 elif self
.numerator
< 0:
625 return '$$-\\frac{{{}}}{{{}}}$$'.format(-self
.numerator
,
628 return '$$\\frac{{{}}}{{{}}}$$'.format(self
.numerator
,