Fix Symbol == LinExpr comparisons
[linpy.git] / linpy / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of LinPy.
4 #
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
17
18 import ast
19 import functools
20 import numbers
21 import re
22
23 from collections import OrderedDict, defaultdict, Mapping
24 from fractions import Fraction, gcd
25
26
# Public names exported by a star-import of this module.
__all__ = [
    'LinExpr',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
32
33
def _polymorphic(func):
    """
    Decorator for binary LinExpr operators: coerce the right operand before
    calling func(left, right).

    A LinExpr right operand is passed through unchanged, any other rational
    number is wrapped in a Rational, and every other type makes the wrapper
    return NotImplemented so that Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
44
45
class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients and
        the constant term must be rational numbers.

        For example, the linear expression x + 2y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        >>> LinExpr([(x, 1), (y, 2)], 1)

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2*y + 1')

        Terms with a zero coefficient are dropped first. Then a linear
        expression with a single symbol of coefficient 1 and no constant term
        is automatically subclassed as a Symbol instance, and a linear
        expression with no symbol, only a constant term, is automatically
        subclassed as a Rational instance (so that e.g. x - x is a Rational).

        Raise TypeError if a key is not a Symbol instance, or if a coefficient
        or the constant term is not a rational number.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Drop zero terms *before* normalizing: otherwise x - x would produce a
        # term-less LinExpr instead of a Rational, and (x + y) - y a LinExpr
        # instead of the Symbol x (which also breaks hash/eq consistency).
        coefficients = [(symbol, Fraction(coefficient))
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        # canonical order, so that equal expressions compare and hash equal
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if symbol is not a
        Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal.
        """
        # _polymorphic has already converted `other` to a LinExpr (or returned
        # NotImplemented for unsupported operand types).
        return (self._coefficients == other._coefficients and
                self._constant == other._constant)

    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant term, so that all
        values become integers.

        >>> x = Symbol('x')
        >>> (x/2 + Rational(1, 3)).scaleint()
        3*x + 2
        """
        # fractions.gcd was removed in Python 3.9; math.gcd is the replacement
        import math
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs. All substitutions are applied
        simultaneously, so symbols can for instance be swapped:

        >>> e.subs({x: y, y: x})
        2*x + y + 1
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for old in substitutions:
            if not isinstance(old, Symbol):
                raise TypeError('symbols must be Symbol instances')
        # Rebuild the expression term by term from the ORIGINAL coefficients,
        # so that a substituted expression is never substituted again.
        result = Rational(self._constant)
        for old, coefficient in self._coefficients.items():
            result += coefficient * substitutions.get(old, old)
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python ast node. Raise SyntaxError for
        unsupported nodes; non-linear products or quotients raise TypeError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # root node produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif (isinstance(node, getattr(ast, 'Constant', ())) and
                isinstance(node.value, numbers.Number) and
                not isinstance(node.value, bool)):
            # numeric literal on Python >= 3.8
            return Rational(node.value)
        elif type(node).__name__ == 'Num':
            # numeric literal on Python < 3.8; ast.Num itself is deprecated
            # and removed in Python 3.14, so it is not referenced directly
            return Rational(node.n)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # Matches a number (not part of an identifier such as 'x1a') or a closing
    # parenthesis, directly followed by an identifier or an opening
    # parenthesis; used to insert implicit multiplications.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        tree = ast.parse(string, mode='eval')
        expr = cls._fromast(tree)
        if not isinstance(expr, cls):
            raise SyntaxError('invalid syntax')
        return expr

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        def latex(value):
            # Coefficients and the constant term are stored as plain Fraction
            # objects, which have no _repr_latex_ method: go through Rational.
            return Rational(value)._repr_latex_().strip('$')

        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += latex(coefficient)
            elif coefficient > 0:
                string += ' + ' + latex(coefficient)
            elif coefficient < 0:
                string += ' - ' + latex(-coefficient)
            string += symbol._repr_latex_().strip('$')
        constant = self.constant
        if len(string) == 0:
            string += latex(constant)
        elif constant > 0:
            string += ' + ' + latex(constant)
        elif constant < 0:
            string += ' - ' + latex(-constant)
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a linear expression from a sympy expression. Raise TypeError if
        the sympy expression is not linear, has non-rational coefficients,
        contains dummy symbols, or does not convert to a cls instance.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            if not getattr(coefficient, 'is_Rational', False):
                # e.g. a sympy Float, which has no exact p/q representation
                raise TypeError(
                    'non-rational coefficient in expression: {!r}'.format(expr))
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # we cannot properly convert dummy symbols
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(expr))
        expr = LinExpr(coefficients, constant)
        if not isinstance(expr, cls):
            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
        return expr

    def tosympy(self):
        """
        Convert the linear expression to a sympy expression.

        Values are converted to exact sympy.Rational numbers, so the result is
        a sympy object even for a constant expression. Symbols are converted
        by name only, so a Dummy becomes a plain sympy.Symbol.
        """
        import sympy
        constant = self.constant
        expr = sympy.Rational(constant.numerator, constant.denominator)
        for symbol, coefficient in self.coefficients():
            factor = sympy.Rational(coefficient.numerator,
                                    coefficient.denominator)
            expr += factor * sympy.Symbol(symbol.name)
        return expr
447
448
class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument. The name must
        parse as a single Python identifier: raise TypeError if it is not a
        string, and SyntaxError otherwise.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        # Parse as one expression: this rejects empty or comment-only strings
        # (formerly an IndexError), several statements such as 'x; y', and
        # assignments such as 'y = x' (formerly silently turned into 'x').
        node = ast.parse(name, mode='eval')
        if not isinstance(node.body, ast.Name):
            raise SyntaxError('invalid syntax')
        self = object.__new__(cls)
        # the name must be set before self is hashed below
        self._name = node.body.id
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> x, y = symbols('x y')
        >>> sorted([y, x], key=Symbol.sortkey)
        [x, y]
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        # Symbols compare by sort key. For any other operand, return
        # NotImplemented so that Python falls back to the reflected
        # LinExpr.__eq__ (which handles Symbol == LinExpr comparisons).
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)
516
517
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name.

    names is either a single string, in which names are separated by commas
    and/or whitespace, or an iterable of name strings. This is a convenient
    way to define several symbols at once.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    if not isinstance(names, str):
        return tuple(Symbol(item) for item in names)
    # commas act as separators just like whitespace
    tokens = names.replace(',', ' ').split()
    return tuple(Symbol(token) for token in tokens)
531
532
class Dummy(Symbol):
    """
    A variation of Symbol in which all symbols are unique and identified by
    an internal count index. If no name is supplied, the name 'Dummy_<index>'
    is used, where <index> is the value of that internal counter. This is
    useful when a unique, temporary variable is needed and the name of the
    variable used in the expression is not important.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # class-wide counter, incremented once per Dummy created
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol. The given name (stripped of surrounding
        whitespace) is used for display only; uniqueness comes from the
        internal index. Raise TypeError if name is neither None nor a string.
        """
        if name is not None and not isinstance(name, str):
            raise TypeError('name must be a string')
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        self = object.__new__(cls)
        # _name and _index must exist before self is hashed below
        self._index = index
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count = index + 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        # the index makes two dummies with the same name distinct
        return self._name, self._index

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)
584
585
class Rational(LinExpr, Fraction):
    """
    A particular case of linear expressions are rational values, i.e. linear
    expressions consisting only of a constant term, with no symbol. They are
    implemented by the Rational class, that inherits from both LinExpr and
    fractions.Fraction classes.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Return the rational value built from the same arguments that
        fractions.Fraction accepts.
        """
        self = object.__new__(cls)
        value = Fraction(numerator, denominator)
        # LinExpr state: no linear terms, only a constant term
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # Fraction state, read by the inherited Fraction methods
        self._numerator, self._denominator = value.numerator, value.denominator
        return self

    def __hash__(self):
        # hash like an equal Fraction/int so that mixed dict keys agree
        return Fraction.__hash__(self)

    @property
    def constant(self):
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        num, den = self.numerator, self.denominator
        if den != 1:
            return '{!r}/{!r}'.format(num, den)
        return '{!r}'.format(num)

    def _repr_latex_(self):
        num, den = self.numerator, self.denominator
        if den == 1:
            return '$${}$$'.format(num)
        if num < 0:
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-num, den)
        return '$$\\frac{{{}}}{{{}}}$$'.format(num, den)