Improve comparison methods in LinExpr
[linpy.git] / linpy / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of LinPy.
4 #
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
17
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from fractions import Fraction

try:
    from collections.abc import Mapping
except ImportError:  # Python < 3.3
    from collections import Mapping

try:
    from math import gcd
except ImportError:  # Python < 3.5
    from fractions import gcd
25
26
# Public names exported by ``from linexprs import *``.
__all__ = [
    'LinExpr',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
32
33
def _polymorphic(func):
    """
    Decorate a binary LinExpr method so that its right operand may also be a
    rational number.

    A LinExpr right operand is passed through unchanged, a numbers.Rational
    one is first wrapped in a Rational, and any other type makes the wrapper
    return NotImplemented so that Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
44
45
class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients and
        the constant term must be rational numbers.

        For example, the linear expression x + 2y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        >>> LinExpr([(x, 1), (y, 2)], 1)

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2*y + 1')

        Terms with a zero coefficient are dropped, and when a symbol appears
        several times in a sequence its last coefficient wins (as with
        dict()). After this normalization, a linear expression with a single
        symbol of coefficient 1 and no constant term is automatically
        subclassed as a Symbol instance, and a linear expression with no
        symbol, only a constant term, is automatically subclassed as a
        Rational instance.

        Raise TypeError if a key is not a Symbol instance, if a coefficient or
        the constant is not a rational number, or if a string is given
        together with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Normalize *before* picking the result class. Otherwise e.g.
        # (x + y) - y would be a plain LinExpr equal to the Symbol x but with
        # a different hash, and x - x a LinExpr instead of Rational(0).
        merged = {}
        for symbol, coefficient in coefficients:
            merged[symbol] = Fraction(coefficient)
        terms = [(symbol, coefficient)
            for symbol, coefficient in merged.items() if coefficient != 0]
        if not terms:
            return Rational(constant)
        if len(terms) == 1 and constant == 0:
            symbol, coefficient = terms[0]
            if coefficient == 1:
                return symbol
        # Canonical symbol order, relied upon by __eq__ and __hash__.
        terms.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(terms)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if symbol is not a
        Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        # Coefficients are stored in canonical (sortkey) order, so equal
        # expressions produce equal tuples.
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal.
        """
        return self._coefficients == other._coefficients and \
            self._constant == other._constant

    # The ordering operators do not return booleans: each builds a
    # Polyhedron from the constraint. NOTE(review): Polyhedron([], [e]) is
    # assumed to mean "no equalities, inequality e >= 0" (see the polyhedra
    # module); strict inequalities subtract 1, i.e. integer semantics.

    @_polymorphic
    def __lt__(self, other):
        """
        Return the Polyhedron for self < other, encoded as other - self - 1.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self - 1])

    @_polymorphic
    def __le__(self, other):
        """
        Return the Polyhedron for self <= other, encoded as other - self.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self])

    @_polymorphic
    def __ge__(self, other):
        """
        Return the Polyhedron for self >= other, encoded as self - other.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other])

    @_polymorphic
    def __gt__(self, other):
        """
        Return the Polyhedron for self > other, encoded as self - other - 1.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other - 1])

    def scaleint(self):
        """
        Return the expression multiplied by its lowest common denominator to
        make all values integer.
        """
        # values() always yields at least the constant, so reduce() never
        # sees an empty sequence.
        lcd = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcd

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs.

        >>> e.subs({x: y, y: x})
        2*x + y + 1

        Substitution values must be linear expressions or rational numbers.
        The result is always a LinExpr, even when every symbol is replaced by
        a number (it is then a Rational).
        """
        if expression is None:
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for old in substitutions:
            if not isinstance(old, Symbol):
                raise TypeError('symbols must be Symbol instances')
        # Start from a Rational rather than the bare Fraction constant, so
        # the result is a LinExpr even if no symbolic term remains.
        result = Rational(self._constant)
        for old, coefficient in self._coefficients.items():
            result += coefficient * substitutions.get(old, old)
        return result

    @classmethod
    def _fromast(cls, node):
        # Recursively convert a Python AST into a linear expression. Only
        # names, numeric literals, unary +/- and the binary operators
        # + - * / are accepted; anything else raises SyntaxError.
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif type(node).__name__ in ('Constant', 'Num'):
            # Numeric literals are ast.Constant nodes since Python 3.8 and
            # ast.Num nodes before (ast.Num was removed in Python 3.14).
            if type(node).__name__ == 'Constant':
                value = node.value
            else:
                value = node.n
            # True/False are ints but not numeric literals here, and complex
            # numbers cannot become a Rational.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return Rational(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # Matches a number or a closing parenthesis immediately followed by a
    # name or an opening parenthesis, to insert implicit multiplications.
    # The \b keeps digits that are part of an identifier (e.g. 'x1y') intact.
    _RE_NUM_VAR = re.compile(r'(\b\d+|\))\s*([^\W\d]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted.
        """
        # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # The second positional parameter of ast.parse() is the filename, so
        # the parsing mode must be given by keyword.
        tree = ast.parse(string, mode='eval')
        expr = cls._fromast(tree)
        if not isinstance(expr, cls):
            raise SyntaxError('invalid syntax')
        return expr

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        # Coefficients and the constant are stored as plain
        # fractions.Fraction values, which have no _repr_latex_() method;
        # wrap them in Rational to reuse its LaTeX formatting.
        def latex(value):
            return Rational(value)._repr_latex_().strip('$')

        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += latex(coefficient)
            elif coefficient > 0:
                string += ' + ' + latex(coefficient)
            else:
                string += ' - ' + latex(-coefficient)
            string += symbol._repr_latex_().strip('$')
        constant = self.constant
        if len(string) == 0:
            string += latex(constant)
        elif constant > 0:
            string += ' + ' + latex(constant)
        elif constant < 0:
            string += ' - ' + latex(-constant)
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        # Return str(self), parenthesized unless the expression is a bare
        # constant or symbol (or always=True forces parentheses).
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a linear expression from a SymPy expression. Raise TypeError if
        the SymPy expression is not linear, has non-rational coefficients, or
        cannot be converted to an instance of cls.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            # Only exact rational SymPy numbers can become a Fraction; reject
            # e.g. Float coefficients explicitly.
            if not getattr(coefficient, 'is_Rational', False):
                raise TypeError(
                    'non-rational coefficient: {!r}'.format(coefficient))
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # We cannot properly convert dummy symbols with respect to
                # symbol equalities.
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(expr))
        expr = LinExpr(coefficients, constant)
        if not isinstance(expr, cls):
            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
        return expr

    def tosympy(self):
        """
        Convert the linear expression to a SymPy expression, using exact
        sympy.Rational values for the coefficients and the constant.
        """
        import sympy
        constant = self.constant
        # Starting from a SymPy number (rather than the int 0) guarantees a
        # SymPy result; previously Rational.tosympy() returned a linpy
        # Rational.
        expr = sympy.Rational(constant.numerator, constant.denominator)
        for symbol, coefficient in self.coefficients():
            value = sympy.Rational(coefficient.numerator, coefficient.denominator)
            expr += value * sympy.Symbol(symbol.name)
        return expr
445
446
class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    __slots__ = (
        '_name',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument. Raise
        TypeError if name is not a string, and SyntaxError if it is not a
        single Python identifier.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        # Parse as a single expression: multi-statement input such as 'x; y'
        # is rejected instead of being truncated to its first statement, and
        # an empty string raises SyntaxError instead of IndexError. Going
        # through ast keeps identifier normalization consistent with
        # LinExpr.fromstring().
        node = ast.parse(name, mode='eval')
        if not isinstance(node.body, ast.Name):
            raise SyntaxError('invalid syntax')
        name = node.body.id
        self = object.__new__(cls)
        self._name = name
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def _coefficients(self):
        # This is not implemented as an attribute, because __hash__ is not
        # callable in __new__ in class Dummy.
        return {self: Fraction(1)}

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        # Symbols compare by sortkey, so a Symbol never equals a Dummy of the
        # same name.
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)
526
527
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name, in order. The names are
    given either as a single string delimited by commas and/or whitespace, or
    as a sequence of strings. It is useful to define several symbols at once.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    if isinstance(names, str):
        # Treat commas as whitespace, then split on any whitespace run.
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
541
542
class Dummy(Symbol):
    """
    A variation of Symbol in which all symbols are unique and identified by
    an internal count index. If a name is not supplied then a string value
    of the count index will be used. This is useful when a unique, temporary
    variable is needed and the name of the variable used in the expression
    is not important.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # Number of Dummy instances created so far, i.e. the index given to the
    # next one.
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol named after the given string, or
        'Dummy_<index>' when no name is given.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        self = super().__new__(cls, name)
        # The index is attached only after Symbol.__new__ succeeded, and the
        # counter only advances for successfully created instances.
        self._index = index
        Dummy._count = index + 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        # The index makes the key unique among dummies sharing a name.
        return (self._name, self._index)

    def __repr__(self):
        return '_' + self.name

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)
587
588
class Rational(LinExpr, Fraction):
    """
    A particular case of linear expressions are rational values, i.e. linear
    expressions consisting only of a constant term, with no symbol. They are
    implemented by the Rational class, that inherits from both LinExpr and
    fractions.Fraction classes.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    ) + Fraction.__slots__

    def __new__(cls, numerator=0, denominator=None):
        """
        Return a Rational equal to Fraction(numerator, denominator).
        """
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        # LinExpr view: no linear terms, only a constant.
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # Fraction view: the attributes behind .numerator/.denominator.
        self._numerator = value.numerator
        self._denominator = value.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        num, den = self.numerator, self.denominator
        if den != 1:
            return '{!r}/{!r}'.format(num, den)
        return '{!r}'.format(num)

    def _repr_latex_(self):
        num, den = self.numerator, self.denominator
        if den == 1:
            return '$${}$$'.format(num)
        if num < 0:
            # Put the sign in front of the fraction bar.
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-num, den)
        return '$$\\frac{{{}}}{{{}}}$$'.format(num, den)