1 # Copyright 2014 MINES ParisTech
3 # This file is part of LinPy.
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
23 from collections
import OrderedDict
, defaultdict
, Mapping
24 from fractions
import Fraction
, gcd
29 'Symbol', 'Dummy', 'symbols',
34 def _polymorphic(func
):
35 @functools.wraps(func
)
36 def wrapper(left
, right
):
37 if isinstance(right
, LinExpr
):
38 return func(left
, right
)
39 elif isinstance(right
, numbers
.Rational
):
40 right
= Rational(right
)
41 return func(left
, right
)
48 A linear expression consists of a list of coefficient-variable pairs
49 that capture the linear terms, plus a constant term. Linear expressions
50 are used to build constraints. They are temporary objects that typically
53 Linear expressions are generally built using overloaded operators. For
54 example, if x is a Symbol, then x + 1 is an instance of LinExpr.
56 LinExpr instances are hashable, and should be treated as immutable.
59 def __new__(cls
, coefficients
=None, constant
=0):
61 Return a linear expression from a dictionary or a sequence, that maps
62 symbols to their coefficients, and a constant term. The coefficients and
63 the constant term must be rational numbers.
65 For example, the linear expression x + 2y + 1 can be constructed using
66 one of the following instructions:
68 >>> x, y = symbols('x y')
69 >>> LinExpr({x: 1, y: 2}, 1)
70 >>> LinExpr([(x, 1), (y, 2)], 1)
72 However, it may be easier to use overloaded operators:
74 >>> x, y = symbols('x y')
77 Alternatively, linear expressions can be constructed from a string:
79 >>> LinExpr('x + 2*y + 1')
81 A linear expression with a single symbol of coefficient 1 and no
82 constant term is automatically subclassed as a Symbol instance. A linear
83 expression with no symbol, only a constant term, is automatically
84 subclassed as a Rational instance.
86 if isinstance(coefficients
, str):
88 raise TypeError('too many arguments')
89 return LinExpr
.fromstring(coefficients
)
90 if coefficients
is None:
91 return Rational(constant
)
92 if isinstance(coefficients
, Mapping
):
93 coefficients
= coefficients
.items()
94 coefficients
= list(coefficients
)
95 for symbol
, coefficient
in coefficients
:
96 if not isinstance(symbol
, Symbol
):
97 raise TypeError('symbols must be Symbol instances')
98 if not isinstance(coefficient
, numbers
.Rational
):
99 raise TypeError('coefficients must be rational numbers')
100 if not isinstance(constant
, numbers
.Rational
):
101 raise TypeError('constant must be a rational number')
102 if len(coefficients
) == 0:
103 return Rational(constant
)
104 if len(coefficients
) == 1 and constant
== 0:
105 symbol
, coefficient
= coefficients
[0]
108 coefficients
= [(symbol
, Fraction(coefficient
))
109 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
110 coefficients
.sort(key
=lambda item
: item
[0].sortkey())
111 self
= object().__new
__(cls
)
112 self
._coefficients
= OrderedDict(coefficients
)
113 self
._constant
= Fraction(constant
)
114 self
._symbols
= tuple(self
._coefficients
)
115 self
._dimension
= len(self
._symbols
)
118 def coefficient(self
, symbol
):
120 Return the coefficient value of the given symbol, or 0 if the symbol
121 does not appear in the expression.
123 if not isinstance(symbol
, Symbol
):
124 raise TypeError('symbol must be a Symbol instance')
125 return self
._coefficients
.get(symbol
, Fraction(0))
127 __getitem__
= coefficient
129 def coefficients(self
):
131 Iterate over the pairs (symbol, value) of linear terms in the
132 expression. The constant term is ignored.
134 yield from self
._coefficients
.items()
139 The constant term of the expression.
141 return self
._constant
146 The tuple of symbols present in the expression, sorted according to
154 The dimension of the expression, i.e. the number of symbols present in
157 return self
._dimension
160 return hash((tuple(self
._coefficients
.items()), self
._constant
))
162 def isconstant(self
):
164 Return True if the expression only consists of a constant term. In this
165 case, it is a Rational instance.
171 Return True if an expression only consists of a symbol with coefficient
172 1. In this case, it is a Symbol instance.
178 Iterate over the coefficient values in the expression, and the constant
181 yield from self
._coefficients
.values()
194 def __add__(self
, other
):
196 Return the sum of two linear expressions.
198 coefficients
= defaultdict(Fraction
, self
._coefficients
)
199 for symbol
, coefficient
in other
._coefficients
.items():
200 coefficients
[symbol
] += coefficient
201 constant
= self
._constant
+ other
._constant
202 return LinExpr(coefficients
, constant
)
207 def __sub__(self
, other
):
209 Return the difference between two linear expressions.
211 coefficients
= defaultdict(Fraction
, self
._coefficients
)
212 for symbol
, coefficient
in other
._coefficients
.items():
213 coefficients
[symbol
] -= coefficient
214 constant
= self
._constant
- other
._constant
215 return LinExpr(coefficients
, constant
)
218 def __rsub__(self
, other
):
221 def __mul__(self
, other
):
223 Return the product of the linear expression by a rational.
225 if isinstance(other
, numbers
.Rational
):
226 coefficients
= ((symbol
, coefficient
* other
)
227 for symbol
, coefficient
in self
._coefficients
.items())
228 constant
= self
._constant
* other
229 return LinExpr(coefficients
, constant
)
230 return NotImplemented
234 def __truediv__(self
, other
):
236 Return the quotient of the linear expression by a rational.
238 if isinstance(other
, numbers
.Rational
):
239 coefficients
= ((symbol
, coefficient
/ other
)
240 for symbol
, coefficient
in self
._coefficients
.items())
241 constant
= self
._constant
/ other
242 return LinExpr(coefficients
, constant
)
243 return NotImplemented
246 def __eq__(self
, other
):
248 Test whether two linear expressions are equal.
250 return self
._coefficients
== other
._coefficients
and \
251 self
._constant
== other
._constant
254 def __lt__(self
, other
):
255 from .polyhedra
import Polyhedron
256 return Polyhedron([], [other
- self
- 1])
259 def __le__(self
, other
):
260 from .polyhedra
import Polyhedron
261 return Polyhedron([], [other
- self
])
264 def __ge__(self
, other
):
265 from .polyhedra
import Polyhedron
266 return Polyhedron([], [self
- other
])
269 def __gt__(self
, other
):
270 from .polyhedra
import Polyhedron
271 return Polyhedron([], [self
- other
- 1])
275 Return the expression multiplied by its lowest common denominator to
276 make all values integer.
278 lcd
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
279 [value
.denominator
for value
in self
.values()])
282 def subs(self
, symbol
, expression
=None):
284 Substitute the given symbol by an expression and return the resulting
285 expression. Raise TypeError if the resulting expression is not linear.
287 >>> x, y = symbols('x y')
292 To perform multiple substitutions at once, pass a sequence or a
293 dictionary of (old, new) pairs to subs.
295 >>> e.subs({x: y, y: x})
298 if expression
is None:
299 substitutions
= dict(symbol
)
301 substitutions
= {symbol
: expression
}
302 for symbol
in substitutions
:
303 if not isinstance(symbol
, Symbol
):
304 raise TypeError('symbols must be Symbol instances')
305 result
= self
._constant
306 for symbol
, coefficient
in self
._coefficients
.items():
307 expression
= substitutions
.get(symbol
, symbol
)
308 result
+= coefficient
* expression
312 def _fromast(cls
, node
):
313 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
314 return cls
._fromast
(node
.body
[0])
315 elif isinstance(node
, ast
.Expr
):
316 return cls
._fromast
(node
.value
)
317 elif isinstance(node
, ast
.Name
):
318 return Symbol(node
.id)
319 elif isinstance(node
, ast
.Num
):
320 return Rational(node
.n
)
321 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
322 return -cls
._fromast
(node
.operand
)
323 elif isinstance(node
, ast
.BinOp
):
324 left
= cls
._fromast
(node
.left
)
325 right
= cls
._fromast
(node
.right
)
326 if isinstance(node
.op
, ast
.Add
):
328 elif isinstance(node
.op
, ast
.Sub
):
330 elif isinstance(node
.op
, ast
.Mult
):
332 elif isinstance(node
.op
, ast
.Div
):
334 raise SyntaxError('invalid syntax')
336 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d]\w*|\()')
339 def fromstring(cls
, string
):
341 Create an expression from a string. Raise SyntaxError if the string is
342 not properly formatted.
344 # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
345 string
= LinExpr
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
346 tree
= ast
.parse(string
, 'eval')
347 expr
= cls
._fromast
(tree
)
348 if not isinstance(expr
, cls
):
349 raise SyntaxError('invalid syntax')
354 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
358 elif coefficient
== -1:
359 string
+= '-' if i
== 0 else ' - '
361 string
+= '{}*'.format(coefficient
)
362 elif coefficient
> 0:
363 string
+= ' + {}*'.format(coefficient
)
365 string
+= ' - {}*'.format(-coefficient
)
366 string
+= '{}'.format(symbol
)
367 constant
= self
.constant
369 string
+= '{}'.format(constant
)
371 string
+= ' + {}'.format(constant
)
373 string
+= ' - {}'.format(-constant
)
376 def _repr_latex_(self
):
378 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
382 elif coefficient
== -1:
383 string
+= '-' if i
== 0 else ' - '
385 string
+= '{}'.format(coefficient
._repr
_latex
_().strip('$'))
386 elif coefficient
> 0:
387 string
+= ' + {}'.format(coefficient
._repr
_latex
_().strip('$'))
388 elif coefficient
< 0:
389 string
+= ' - {}'.format((-coefficient
)._repr
_latex
_().strip('$'))
390 string
+= '{}'.format(symbol
._repr
_latex
_().strip('$'))
391 constant
= self
.constant
393 string
+= '{}'.format(constant
._repr
_latex
_().strip('$'))
395 string
+= ' + {}'.format(constant
._repr
_latex
_().strip('$'))
397 string
+= ' - {}'.format((-constant
)._repr
_latex
_().strip('$'))
398 return '$${}$$'.format(string
)
400 def _parenstr(self
, always
=False):
402 if not always
and (self
.isconstant() or self
.issymbol()):
405 return '({})'.format(string
)
408 def fromsympy(cls
, expr
):
410 Create a linear expression from a SymPy expression. Raise TypeError is
411 the sympy expression is not linear.
416 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
417 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
418 if symbol
== sympy
.S
.One
:
419 constant
= coefficient
420 elif isinstance(symbol
, sympy
.Dummy
):
421 # We cannot properly convert dummy symbols with respect to
423 raise TypeError('cannot convert dummy symbols')
424 elif isinstance(symbol
, sympy
.Symbol
):
425 symbol
= Symbol(symbol
.name
)
426 coefficients
.append((symbol
, coefficient
))
428 raise TypeError('non-linear expression: {!r}'.format(expr
))
429 expr
= LinExpr(coefficients
, constant
)
430 if not isinstance(expr
, cls
):
431 raise TypeError('cannot convert to a {} instance'.format(cls
.__name
__))
436 Convert the linear expression to a SymPy expression.
440 for symbol
, coefficient
in self
.coefficients():
441 term
= coefficient
* sympy
.Symbol(symbol
.name
)
443 expr
+= self
.constant
447 class Symbol(LinExpr
):
449 Symbols are the basic components to build expressions and constraints.
450 They correspond to mathematical variables. Symbols are instances of
451 class LinExpr and inherit its functionalities.
453 Two instances of Symbol are equal if they have the same name.
463 def __new__(cls
, name
):
465 Return a symbol with the name string given in argument.
467 if not isinstance(name
, str):
468 raise TypeError('name must be a string')
469 node
= ast
.parse(name
)
471 name
= node
.body
[0].value
.id
472 except (AttributeError, SyntaxError):
473 raise SyntaxError('invalid syntax')
474 self
= object().__new
__(cls
)
476 self
._constant
= Fraction(0)
477 self
._symbols
= (self
,)
482 def _coefficients(self
):
483 # This is not implemented as an attribute, because __hash__ is not
484 # callable in __new__ in class Dummy.
485 return {self
: Fraction(1)}
490 The name of the symbol.
495 return hash(self
.sortkey())
499 Return a sorting key for the symbol. It is useful to sort a list of
500 symbols in a consistent order, as comparison functions are overridden
501 (see the documentation of class LinExpr).
503 >>> sort(symbols, key=Symbol.sortkey)
510 def __eq__(self
, other
):
511 if isinstance(other
, Symbol
):
512 return self
.sortkey() == other
.sortkey()
513 return NotImplemented
517 Return a new Dummy symbol instance with the same name.
519 return Dummy(self
.name
)
524 def _repr_latex_(self
):
525 return '$${}$$'.format(self
.name
)
530 This function returns a tuple of symbols whose names are taken from a comma
531 or whitespace delimited string, or a sequence of strings. It is useful to
532 define several symbols at once.
534 >>> x, y = symbols('x y')
535 >>> x, y = symbols('x, y')
536 >>> x, y = symbols(['x', 'y'])
538 if isinstance(names
, str):
539 names
= names
.replace(',', ' ').split()
540 return tuple(Symbol(name
) for name
in names
)
545 A variation of Symbol in which all symbols are unique and identified by
546 an internal count index. If a name is not supplied then a string value
547 of the count index will be used. This is useful when a unique, temporary
548 variable is needed and the name of the variable used in the expression
551 Unlike Symbol, Dummy instances with the same name are not equal:
554 >>> x1, x2 = Dummy('x'), Dummy('x')
565 def __new__(cls
, name
=None):
567 Return a fresh dummy symbol with the name string given in argument.
570 name
= 'Dummy_{}'.format(Dummy
._count
)
571 self
= super().__new
__(cls
, name
)
572 self
._index
= Dummy
._count
577 return hash(self
.sortkey())
580 return self
._name
, self
._index
583 return '_{}'.format(self
.name
)
585 def _repr_latex_(self
):
586 return '$${}_{{{}}}$$'.format(self
.name
, self
._index
)
589 class Rational(LinExpr
, Fraction
):
591 A particular case of linear expressions are rational values, i.e. linear
592 expressions consisting only of a constant term, with no symbol. They are
593 implemented by the Rational class, that inherits from both LinExpr and
594 fractions.Fraction classes.
602 ) + Fraction
.__slots
__
604 def __new__(cls
, numerator
=0, denominator
=None):
605 self
= object().__new
__(cls
)
606 self
._coefficients
= {}
607 self
._constant
= Fraction(numerator
, denominator
)
610 self
._numerator
= self
._constant
.numerator
611 self
._denominator
= self
._constant
.denominator
615 return Fraction
.__hash
__(self
)
621 def isconstant(self
):
625 return Fraction
.__bool
__(self
)
628 if self
.denominator
== 1:
629 return '{!r}'.format(self
.numerator
)
631 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)
633 def _repr_latex_(self
):
634 if self
.denominator
== 1:
635 return '$${}$$'.format(self
.numerator
)
636 elif self
.numerator
< 0:
637 return '$$-\\frac{{{}}}{{{}}}$$'.format(-self
.numerator
,
640 return '$$\\frac{{{}}}{{{}}}$$'.format(self
.numerator
,