# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy. If not, see <http://www.gnu.org/licenses/>.
23 from collections
import defaultdict
, Mapping
, OrderedDict
24 from fractions
import Fraction
, gcd
36 def _polymorphic(func
):
37 @functools.wraps(func
)
38 def wrapper(left
, right
):
39 if isinstance(right
, LinExpr
):
40 return func(left
, right
)
41 elif isinstance(right
, numbers
.Rational
):
42 right
= Rational(right
)
43 return func(left
, right
)
50 A linear expression consists of a list of coefficient-variable pairs
51 that capture the linear terms, plus a constant term. Linear expressions
52 are used to build constraints. They are temporary objects that typically
55 Linear expressions are generally built using overloaded operators. For
56 example, if x is a Symbol, then x + 1 is an instance of LinExpr.
58 LinExpr instances are hashable, and should be treated as immutable.
61 def __new__(cls
, coefficients
=None, constant
=0):
63 Return a linear expression from a dictionary or a sequence, that maps
64 symbols to their coefficients, and a constant term. The coefficients
65 and the constant term must be rational numbers.
67 For example, the linear expression x + 2*y + 1 can be constructed using
68 one of the following instructions:
70 >>> x, y = symbols('x y')
71 >>> LinExpr({x: 1, y: 2}, 1)
72 >>> LinExpr([(x, 1), (y, 2)], 1)
74 However, it may be easier to use overloaded operators:
76 >>> x, y = symbols('x y')
79 Alternatively, linear expressions can be constructed from a string:
81 >>> LinExpr('x + 2y + 1')
83 A linear expression with a single symbol of coefficient 1 and no
84 constant term is automatically subclassed as a Symbol instance. A
85 linear expression with no symbol, only a constant term, is
86 automatically subclassed as a Rational instance.
88 if isinstance(coefficients
, str):
90 raise TypeError('too many arguments')
91 return LinExpr
.fromstring(coefficients
)
92 if coefficients
is None:
93 return Rational(constant
)
94 if isinstance(coefficients
, Mapping
):
95 coefficients
= coefficients
.items()
96 coefficients
= list(coefficients
)
97 for symbol
, coefficient
in coefficients
:
98 if not isinstance(symbol
, Symbol
):
99 raise TypeError('symbols must be Symbol instances')
100 if not isinstance(coefficient
, numbers
.Rational
):
101 raise TypeError('coefficients must be rational numbers')
102 if not isinstance(constant
, numbers
.Rational
):
103 raise TypeError('constant must be a rational number')
104 if len(coefficients
) == 0:
105 return Rational(constant
)
106 if len(coefficients
) == 1 and constant
== 0:
107 symbol
, coefficient
= coefficients
[0]
110 coefficients
= [(symbol_
, Fraction(coefficient_
))
111 for symbol_
, coefficient_
in coefficients
112 if coefficient_
!= 0]
113 coefficients
.sort(key
=lambda item
: item
[0].sortkey())
114 self
= object().__new
__(cls
)
115 self
._coefficients
= OrderedDict(coefficients
)
116 self
._constant
= Fraction(constant
)
117 self
._symbols
= tuple(self
._coefficients
)
118 self
._dimension
= len(self
._symbols
)
def coefficient(self, symbol):
    """
    Look up the coefficient attached to a symbol in this expression.

    Return the stored coefficient, or Fraction(0) when the symbol has no
    entry in the expression. Raise TypeError if symbol is not a Symbol
    instance.
    """
    if isinstance(symbol, Symbol):
        return self._coefficients.get(symbol, Fraction(0))
    raise TypeError('symbol must be a Symbol instance')

# Indexing an expression with a symbol (expr[symbol]) is the same lookup.
__getitem__ = coefficient
def coefficients(self):
    """
    Generate the (symbol, value) pairs of the linear terms, one pair per
    symbol stored in the expression. The constant term is not produced.
    """
    for term in self._coefficients.items():
        yield term
142 The constant term of the expression.
144 return self
._constant
149 The tuple of symbols present in the expression, sorted according to
157 The dimension of the expression, i.e. the number of symbols present in
160 return self
._dimension
163 return hash((tuple(self
._coefficients
.items()), self
._constant
))
165 def isconstant(self
):
167 Return True if the expression only consists of a constant term. In this
168 case, it is a Rational instance.
174 Return True if an expression only consists of a symbol with coefficient
175 1. In this case, it is a Symbol instance.
181 Iterate over the coefficient values in the expression, and the constant
184 yield from self
._coefficients
.values()
def __add__(self, other):
    """
    Add another linear expression to this one and return the sum.

    The coefficients of both operands are merged symbol by symbol (a symbol
    missing from this expression starts from a zero Fraction), the constant
    terms are added, and the result is built with LinExpr().
    """
    merged = dict(self._coefficients)
    for sym, value in other._coefficients.items():
        merged[sym] = merged.get(sym, Fraction()) + value
    return LinExpr(merged, self._constant + other._constant)
def __sub__(self, other):
    """
    Subtract another linear expression from this one and return the
    difference.

    The coefficients of other are subtracted symbol by symbol from those of
    this expression (a symbol missing from this expression starts from a
    zero Fraction), the constant terms are subtracted, and the result is
    built with LinExpr().
    """
    merged = dict(self._coefficients)
    for sym, value in other._coefficients.items():
        merged[sym] = merged.get(sym, Fraction()) - value
    return LinExpr(merged, self._constant - other._constant)
221 def __rsub__(self
, other
):
224 def __mul__(self
, other
):
226 Return the product of the linear expression by a rational.
228 if isinstance(other
, numbers
.Rational
):
230 (symbol
, coefficient
* other
)
231 for symbol
, coefficient
in self
._coefficients
.items())
232 constant
= self
._constant
* other
233 return LinExpr(coefficients
, constant
)
234 return NotImplemented
238 def __truediv__(self
, other
):
240 Return the quotient of the linear expression by a rational.
242 if isinstance(other
, numbers
.Rational
):
244 (symbol
, coefficient
/ other
)
245 for symbol
, coefficient
in self
._coefficients
.items())
246 constant
= self
._constant
/ other
247 return LinExpr(coefficients
, constant
)
248 return NotImplemented
def __eq__(self, other):
    """
    Compare two linear expressions for equality.

    The result is a boolean value, not a polyhedron, in contrast with
    LinExpr.__lt__(), LinExpr.__le__(), LinExpr.__ge__() and
    LinExpr.__gt__(). To state that two expressions are equal or not equal
    as a constraint, use the functions Eq() and Ne().
    """
    same_terms = self._coefficients == other._coefficients
    if not same_terms:
        return same_terms
    return self._constant == other._constant
def __lt__(self, other):
    """
    Return the polyhedron built for the comparison self < other.

    The result is Polyhedron([], [other - self - 1]): an empty first list
    and the single expression other - self - 1 in the second list.
    NOTE(review): presumably the second list holds expressions constrained
    to be non-negative -- confirm against Polyhedron.
    """
    # Imported at call time rather than at module level (presumably to
    # avoid a circular import -- confirm).
    from .polyhedra import Polyhedron
    gap = other - self - 1
    return Polyhedron([], [gap])
def __le__(self, other):
    """
    Return the polyhedron built for the comparison self <= other.

    The result is Polyhedron([], [other - self]): an empty first list and
    the single expression other - self in the second list.
    NOTE(review): presumably the second list holds expressions constrained
    to be non-negative -- confirm against Polyhedron.
    """
    # Imported at call time rather than at module level (presumably to
    # avoid a circular import -- confirm).
    from .polyhedra import Polyhedron
    gap = other - self
    return Polyhedron([], [gap])
def __ge__(self, other):
    """
    Return the polyhedron built for the comparison self >= other.

    The result is Polyhedron([], [self - other]): an empty first list and
    the single expression self - other in the second list.
    NOTE(review): presumably the second list holds expressions constrained
    to be non-negative -- confirm against Polyhedron.
    """
    # Imported at call time rather than at module level (presumably to
    # avoid a circular import -- confirm).
    from .polyhedra import Polyhedron
    gap = self - other
    return Polyhedron([], [gap])
def __gt__(self, other):
    """
    Return the polyhedron built for the comparison self > other.

    The result is Polyhedron([], [self - other - 1]): an empty first list
    and the single expression self - other - 1 in the second list.
    NOTE(review): presumably the second list holds expressions constrained
    to be non-negative -- confirm against Polyhedron.
    """
    # Imported at call time rather than at module level (presumably to
    # avoid a circular import -- confirm).
    from .polyhedra import Polyhedron
    gap = self - other - 1
    return Polyhedron([], [gap])
284 Return the expression multiplied by its lowest common denominator to
285 make all values integer.
287 lcd
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
288 [value
.denominator
for value
in self
.values()])
291 def subs(self
, symbol
, expression
=None):
293 Substitute the given symbol by an expression and return the resulting
294 expression. Raise TypeError if the resulting expression is not linear.
296 >>> x, y = symbols('x y')
301 To perform multiple substitutions at once, pass a sequence or a
302 dictionary of (old, new) pairs to subs.
304 >>> e.subs({x: y, y: x})
307 if expression
is None:
308 substitutions
= dict(symbol
)
310 substitutions
= {symbol
: expression
}
311 for symbol
in substitutions
:
312 if not isinstance(symbol
, Symbol
):
313 raise TypeError('symbols must be Symbol instances')
314 result
= Rational(self
._constant
)
315 for symbol
, coefficient
in self
._coefficients
.items():
316 expression
= substitutions
.get(symbol
, symbol
)
317 result
+= coefficient
* expression
321 def _fromast(cls
, node
):
322 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
323 return cls
._fromast
(node
.body
[0])
324 elif isinstance(node
, ast
.Expr
):
325 return cls
._fromast
(node
.value
)
326 elif isinstance(node
, ast
.Name
):
327 return Symbol(node
.id)
328 elif isinstance(node
, ast
.Num
):
329 return Rational(node
.n
)
330 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
331 return -cls
._fromast
(node
.operand
)
332 elif isinstance(node
, ast
.BinOp
):
333 left
= cls
._fromast
(node
.left
)
334 right
= cls
._fromast
(node
.right
)
335 if isinstance(node
.op
, ast
.Add
):
337 elif isinstance(node
.op
, ast
.Sub
):
339 elif isinstance(node
.op
, ast
.Mult
):
341 elif isinstance(node
.op
, ast
.Div
):
343 raise SyntaxError('invalid syntax')
# Pattern used by fromstring() to make implicit multiplication explicit.
# Group 1 is a run of digits or a closing parenthesis; after optional
# whitespace, group 2 is an identifier (a word character that is not a
# digit, followed by further word characters) or an opening parenthesis.
# Substituting r'\1*\2' inserts the operator, e.g. '5x' -> '5*x'.
_RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d]\w*|\()')
348 def fromstring(cls
, string
):
350 Create an expression from a string. Raise SyntaxError if the string is
351 not properly formatted.
353 # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
354 string
= LinExpr
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
355 tree
= ast
.parse(string
, 'eval')
356 expression
= cls
._fromast
(tree
)
357 if not isinstance(expression
, cls
):
358 raise SyntaxError('invalid syntax')
363 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
367 elif coefficient
== -1:
368 string
+= '-' if i
== 0 else ' - '
370 string
+= '{}*'.format(coefficient
)
371 elif coefficient
> 0:
372 string
+= ' + {}*'.format(coefficient
)
374 string
+= ' - {}*'.format(-coefficient
)
375 string
+= '{}'.format(symbol
)
376 constant
= self
.constant
378 string
+= '{}'.format(constant
)
380 string
+= ' + {}'.format(constant
)
382 string
+= ' - {}'.format(-constant
)
385 def _parenstr(self
, always
=False):
387 if not always
and (self
.isconstant() or self
.issymbol()):
390 return '({})'.format(string
)
393 def fromsympy(cls
, expression
):
395 Create a linear expression from a SymPy expression. Raise TypeError is
396 the sympy expression is not linear.
401 for symbol
, coefficient
in expression
.as_coefficients_dict().items():
402 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
403 if symbol
== sympy
.S
.One
:
404 constant
= coefficient
405 elif isinstance(symbol
, sympy
.Dummy
):
406 # We cannot properly convert dummy symbols with respect to
408 raise TypeError('cannot convert dummy symbols')
409 elif isinstance(symbol
, sympy
.Symbol
):
410 symbol
= Symbol(symbol
.name
)
411 coefficients
.append((symbol
, coefficient
))
413 raise TypeError('non-linear expression: {!r}'.format(
415 expression
= LinExpr(coefficients
, constant
)
416 if not isinstance(expression
, cls
):
417 raise TypeError('cannot convert to a {} instance'.format(
423 Convert the linear expression to a SymPy expression.
427 for symbol
, coefficient
in self
.coefficients():
428 term
= coefficient
* sympy
.Symbol(symbol
.name
)
430 expression
+= self
.constant
434 class Symbol(LinExpr
):
436 Symbols are the basic components to build expressions and constraints.
437 They correspond to mathematical variables. Symbols are instances of
438 class LinExpr and inherit its functionalities.
440 Two instances of Symbol are equal if they have the same name.
450 def __new__(cls
, name
):
452 Return a symbol with the name string given in argument.
454 if not isinstance(name
, str):
455 raise TypeError('name must be a string')
456 node
= ast
.parse(name
)
458 name
= node
.body
[0].value
.id
459 except (AttributeError, SyntaxError):
460 raise SyntaxError('invalid syntax')
461 self
= object().__new
__(cls
)
463 self
._constant
= Fraction(0)
464 self
._symbols
= (self
,)
def _coefficients(self):
    """
    Return a fresh one-entry mapping from this symbol to Fraction(1).
    """
    # Built on every access instead of being stored as an attribute,
    # because __hash__ is not callable in __new__ in class Dummy.
    unit = Fraction(1)
    return {self: unit}
477 The name of the symbol.
482 return hash(self
.sortkey())
486 Return a sorting key for the symbol. It is useful to sort a list of
487 symbols in a consistent order, as comparison functions are overridden
488 (see the documentation of class LinExpr).
490 >>> sort(symbols, key=Symbol.sortkey)
def __eq__(self, other):
    """
    Compare this symbol with another object.

    Two Symbol instances are equal when their sortkey() values are equal.
    For any other kind of object NotImplemented is returned, leaving the
    decision to Python's comparison fallback.
    """
    if not isinstance(other, Symbol):
        return NotImplemented
    mine, theirs = self.sortkey(), other.sortkey()
    return mine == theirs
504 Return a new Dummy symbol instance with the same name.
506 return Dummy(self
.name
)
514 This function returns a tuple of symbols whose names are taken from a comma
515 or whitespace delimited string, or a sequence of strings. It is useful to
516 define several symbols at once.
518 >>> x, y = symbols('x y')
519 >>> x, y = symbols('x, y')
520 >>> x, y = symbols(['x', 'y'])
522 if isinstance(names
, str):
523 names
= names
.replace(',', ' ').split()
524 return tuple(Symbol(name
) for name
in names
)
529 A variation of Symbol in which all symbols are unique and identified by
530 an internal count index. If a name is not supplied then a string value
531 of the count index will be used. This is useful when a unique, temporary
532 variable is needed and the name of the variable used in the expression
535 Unlike Symbol, Dummy instances with the same name are not equal:
538 >>> x1, x2 = Dummy('x'), Dummy('x')
549 def __new__(cls
, name
=None):
551 Return a fresh dummy symbol with the name string given in argument.
554 name
= 'Dummy_{}'.format(Dummy
._count
)
555 self
= super().__new
__(cls
, name
)
556 self
._index
= Dummy
._count
561 return hash(self
.sortkey())
564 return self
._name
, self
._index
567 return '_{}'.format(self
.name
)
570 class Rational(LinExpr
, Fraction
):
572 A particular case of linear expressions are rational values, i.e. linear
573 expressions consisting only of a constant term, with no symbol. They are
574 implemented by the Rational class, that inherits from both LinExpr and
575 fractions.Fraction classes.
583 ) + Fraction
.__slots
__
585 def __new__(cls
, numerator
=0, denominator
=None):
586 self
= object().__new
__(cls
)
587 self
._coefficients
= {}
588 self
._constant
= Fraction(numerator
, denominator
)
591 self
._numerator
= self
._constant
.numerator
592 self
._denominator
= self
._constant
.denominator
596 return Fraction
.__hash
__(self
)
602 def isconstant(self
):
606 return Fraction
.__bool
__(self
)
609 if self
.denominator
== 1:
610 return '{!r}'.format(self
.numerator
)
612 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)