Fix LinExpr.subs() implementation
[linpy.git] / linpy / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of LinPy.
4 #
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
17
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from fractions import Fraction

try:
    from math import gcd
except ImportError:  # Python < 3.5
    from fractions import gcd
25
26
# Names exported by ``from <module> import *``.
__all__ = [
    'LinExpr',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
32
33
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is always a LinExpr.

    A right operand that is already a LinExpr is passed through unchanged, a
    rational number is first converted to a Rational instance, and for any
    other operand the wrapper returns NotImplemented without calling func.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
44
45
class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients and
        the constant term must be rational numbers.

        For example, the linear expression x + 2y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        >>> LinExpr([(x, 1), (y, 2)], 1)

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2*y + 1')

        A linear expression with a single symbol of coefficient 1 and no
        constant term is automatically subclassed as a Symbol instance. A linear
        expression with no symbol, only a constant term, is automatically
        subclassed as a Rational instance. Terms whose coefficient is zero are
        discarded before these two rules are applied.

        Raise TypeError if a symbol is not a Symbol instance, if a coefficient
        or the constant is not a rational number, or if a constant is given
        together with a string.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Null terms must be dropped *before* testing for the degenerate
        # cases, otherwise x - x would not collapse to a Rational and
        # x + y - y would not collapse to the Symbol x.
        coefficients = [(symbol, Fraction(coefficient))
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if the argument is
        not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    # expr[symbol] is a shorthand for expr.coefficient(symbol)
    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        """
        Return the difference other - self.
        """
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal.
        """
        # the decorator only ever passes LinExpr instances as other
        return self._coefficients == other._coefficients and \
            self._constant == other._constant

    # The comparison operators below do not return booleans: each one builds
    # and returns the corresponding object of the polyhedra module, which is
    # imported inside the method rather than at module level.

    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by its lowest common denominator to
        make all values integer.
        """
        # lcm(a, b) == a*b // gcd(a, b), folded over every denominator
        lcd = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcd

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs.

        >>> e.subs({x: y, y: x})
        2*x + y + 1

        All substitutions are applied simultaneously, and the result is always
        a LinExpr instance (possibly a Symbol or a Rational).
        """
        if expression is None:
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for old in substitutions:
            if not isinstance(old, Symbol):
                raise TypeError('symbols must be Symbol instances')
        # Start from a Rational rather than the raw Fraction so that the
        # result is a LinExpr even when no symbolic term is left.
        result = Rational(self._constant)
        for old, coefficient in self._coefficients.items():
            new = substitutions.get(old, old)
            result += coefficient * new
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Convert a node of a Python abstract syntax tree into a linear
        expression. Raise SyntaxError if the node is not supported.
        """
        if isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        # Numeric literals are ast.Constant nodes on recent Python versions
        # and ast.Num nodes on older ones; the node type is matched by name so
        # that the deprecated ast.Num attribute is never accessed.
        kind = type(node).__name__
        if kind in ('Constant', 'Num'):
            value = node.value if kind == 'Constant' else node.n
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return Rational(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # a number or closing parenthesis directly followed by an identifier or
    # an opening parenthesis, i.e. an implicit multiplication
    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        # mode must be passed by keyword: the second positional parameter of
        # ast.parse is the file name
        tree = ast.parse(string, mode='eval')
        expr = cls._fromast(tree)
        if not isinstance(expr, cls):
            raise SyntaxError('invalid syntax')
        return expr

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self._constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        def latex(value):
            # coefficients and constant are stored as plain Fractions, which
            # have no _repr_latex_ method: go through Rational to render them
            return Rational(value)._repr_latex_().strip('$')

        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}'.format(latex(coefficient))
            elif coefficient > 0:
                string += ' + {}'.format(latex(coefficient))
            elif coefficient < 0:
                string += ' - {}'.format(latex(-coefficient))
            string += '{}'.format(symbol._repr_latex_().strip('$'))
        constant = self._constant
        if len(string) == 0:
            string += '{}'.format(latex(constant))
        elif constant > 0:
            string += ' + {}'.format(latex(constant))
        elif constant < 0:
            string += ' - {}'.format(latex(-constant))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, enclosed in parentheses
        unless it is a constant or a symbol. If always is true, parentheses
        are added in every case.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a linear expression from a sympy expression. Raise TypeError if
        the sympy expression is not linear.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # we cannot properly convert dummy symbols
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(expr))
        expr = LinExpr(coefficients, constant)
        if not isinstance(expr, cls):
            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
        return expr

    def tosympy(self):
        """
        Convert the linear expression to a sympy expression.
        """
        import sympy
        # Start from a sympy number so that a sympy object is returned even
        # when the expression has no symbol.
        constant = self._constant
        expr = sympy.Rational(constant.numerator, constant.denominator)
        for symbol, coefficient in self._coefficients.items():
            value = sympy.Rational(coefficient.numerator,
                coefficient.denominator)
            expr += value * sympy.Symbol(symbol.name)
        return expr
442
443
class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument.

        Raise TypeError if the name is not a string, and SyntaxError if it
        does not parse as exactly one Python identifier.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        try:
            body = ast.parse(name).body
        except (SyntaxError, ValueError):
            raise SyntaxError('invalid syntax')
        # Accept a lone identifier only: an empty string, several statements
        # ('x; y') or another statement wrapping a name ('return x') are
        # rejected instead of being silently reduced to a name.
        if (len(body) != 1 or not isinstance(body[0], ast.Expr)
                or not isinstance(body[0].value, ast.Name)):
            raise SyntaxError('invalid syntax')
        name = body[0].value.id
        self = object.__new__(cls)
        self._name = name
        # a symbol is the linear expression 1*self + 0
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)
511
512
def symbols(names):
    """
    Build several symbols at once and return them as a tuple.

    The names are given either as a single string, in which they are
    separated by commas and/or whitespace, or as a sequence of strings.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    if isinstance(names, str):
        # commas are treated as plain whitespace separators
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
526
527
class Dummy(Symbol):
    """
    A variation of Symbol in which every instance is unique: each one is
    identified by an internal count index in addition to its name. When no
    name is supplied, one is derived from the count index. This is useful
    when a unique, temporary variable is needed and the name of the variable
    used in the expression is not important.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # number of dummy symbols created so far; also the index of the next one
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol with the name string given in argument.
        Raise TypeError if a name is given and is not a string.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._index = index
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        # the counter only advances once the symbol has been built
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        # the index makes the key differ between dummies sharing a name
        return self._name, self._index

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)
579
580
class Rational(LinExpr, Fraction):
    """
    Rational values are the linear expressions that have no symbol and only
    consist of a constant term. The Rational class implements them, and
    inherits from both the LinExpr and fractions.Fraction classes.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Return the rational value numerator/denominator. The arguments are
        interpreted by the fractions.Fraction constructor.
        """
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        # LinExpr side: no linear term, only a constant
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # Fraction side
        self._numerator = value.numerator
        self._denominator = value.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        # a rational value is its own constant term
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        if self.denominator == 1:
            return '{!r}'.format(self.numerator)
        return '{!r}/{!r}'.format(self.numerator, self.denominator)

    def _repr_latex_(self):
        if self.denominator == 1:
            return '$${}$$'.format(self.numerator)
        if self.numerator < 0:
            # the minus sign is written in front of the fraction
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-self.numerator,
                self.denominator)
        return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
            self.denominator)
625 return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
626 self.denominator)