Simplify LinExpr.values()
[linpy.git] / linpy / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of LinPy.
4 #
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
17
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from fractions import Fraction
from math import gcd
25
26
# Public names exported by "from linpy.linexprs import *".
__all__ = [
    'LinExpr',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
32
33
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand may be either a
    LinExpr or a plain rational number. Rational numbers are converted to
    Rational instances before the wrapped method is called; any other operand
    makes the wrapper return NotImplemented, letting Python try the reflected
    operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
44
45
class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients and
        the constant term must be rational numbers.

        For example, the linear expression x + 2y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        >>> LinExpr([(x, 1), (y, 2)], 1)

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2*y + 1')

        Terms with a zero coefficient are discarded first. Then, a linear
        expression with a single symbol of coefficient 1 and no constant term
        is automatically subclassed as a Symbol instance, and a linear
        expression with no symbol, only a constant term, is automatically
        subclassed as a Rational instance.

        Raise TypeError if a symbol is not a Symbol instance, if a coefficient
        or the constant term is not a rational number, or if a constant term
        is passed together with a string.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Drop zero terms *before* choosing the result class, so that e.g.
        # x - x becomes a Rational and x + y - y becomes the Symbol x, as
        # documented above (and so that such results hash like the
        # Rational/Symbol they compare equal to).
        coefficients = [(symbol, Fraction(coefficient))
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        # Keep terms in a canonical order so that equal expressions have
        # equal (order-sensitive) OrderedDicts and equal hashes.
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression.

        Raise TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                            for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                            for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal.
        """
        return isinstance(other, LinExpr) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    def __le__(self, other):
        """Build the constraint self <= other (polyhedra.Le)."""
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        """Build the constraint self < other (polyhedra.Lt)."""
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        """Build the constraint self >= other (polyhedra.Ge)."""
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        """Build the constraint self > other (polyhedra.Gt)."""
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by its lowest common denominator to
        make all values integer.
        """
        # values() always yields at least the constant term, so the reduction
        # never runs on an empty sequence.
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs. The substitutions are
        simultaneous: each one is applied to the original expression, not to
        the result of the previous ones. If a symbol appears several times in
        a sequence, the last pair wins.

        >>> e.subs({x: y, y: x})
        2*x + y + 1

        Raise TypeError if a substituted item is not a Symbol instance.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for old in substitutions:
            if not isinstance(old, Symbol):
                raise TypeError('symbols must be Symbol instances')
        # Rebuild the expression term by term from the original coefficients,
        # so that a replacement never gets substituted again.
        result = Rational(self._constant)
        for othersymbol, coefficient in self._coefficients.items():
            result += coefficient * substitutions.get(othersymbol, othersymbol)
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node. Only names, numeric
        literals, unary minus and the binary operators +, -, * and / are
        supported; raise SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        number = cls._numberfromast(node)
        if number is not None:
            return Rational(number)
        raise SyntaxError('invalid syntax')

    @staticmethod
    def _numberfromast(node):
        """
        Return the numeric literal held by an AST node, or None if the node
        is not a numeric literal.
        """
        # Python >= 3.8 parses numbers as ast.Constant; older versions produce
        # ast.Num, which is deprecated and removed in Python 3.14.
        if isinstance(node, getattr(ast, 'Constant', ())):
            value = node.value
        elif isinstance(node, getattr(ast, 'Num', ())):
            value = node.n
        else:
            return None
        # bool is an int subclass but was never accepted as ast.Num
        if isinstance(value, bool) or \
                not isinstance(value, (int, float, complex)):
            return None
        return value

    # matches a number or ')' immediately followed by a name or '('
    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        # The second positional parameter of ast.parse is the file name, so
        # the parsing mode must be passed by keyword.
        tree = ast.parse(string, mode='eval')
        expr = cls._fromast(tree)
        if not isinstance(expr, cls):
            raise SyntaxError('invalid syntax')
        return expr

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        def latex(value):
            # Coefficients and the constant term are stored as plain Fraction
            # objects, which have no _repr_latex_ method: wrap them first.
            return Rational(value)._repr_latex_().strip('$')
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}'.format(latex(coefficient))
            elif coefficient > 0:
                string += ' + {}'.format(latex(coefficient))
            elif coefficient < 0:
                string += ' - {}'.format(latex(-coefficient))
            string += '{}'.format(symbol._repr_latex_().strip('$'))
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(latex(constant))
        elif constant > 0:
            string += ' + {}'.format(latex(constant))
        elif constant < 0:
            string += ' - {}'.format(latex(-constant))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol and always is False.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a linear expression from a sympy expression. Raise TypeError if
        the sympy expression is not linear, contains dummy symbols, or cannot
        be converted to an instance of cls.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # we cannot properly convert dummy symbols
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(expr))
        expr = LinExpr(coefficients, constant)
        if not isinstance(expr, cls):
            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
        return expr

    def tosympy(self):
        """
        Convert the linear expression to a sympy expression.
        """
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr
446
447
class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument.

        Raise TypeError if name is not a string, and SyntaxError if it does
        not parse as a single Python identifier.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        node = ast.parse(name)
        # Require exactly one statement consisting of a bare name. This
        # rejects the empty string (no statement at all), inputs such as
        # 'x; y' (several statements) and 'x = y' (an assignment).
        if (len(node.body) != 1
                or not isinstance(node.body[0], ast.Expr)
                or not isinstance(node.body[0].value, ast.Name)):
            raise SyntaxError('invalid syntax')
        name = node.body[0].value.id
        self = object.__new__(cls)
        # _name must be set before building _coefficients, because using self
        # as a dict key hashes it through sortkey(), which reads the name.
        self._name = name
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        """
        Two symbols are equal if their sort keys are equal. Any other operand
        is compared as a linear expression by LinExpr.__eq__, which returns
        NotImplemented for non-numeric, non-LinExpr operands.
        """
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return super().__eq__(other)

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)
513
514
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name. The names are given
    either as a single string, in which commas and whitespace both act as
    separators, or as a sequence of strings. Handy for declaring several
    symbols in one statement.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    tokens = names.replace(',', ' ').split() if isinstance(names, str) else names
    return tuple([Symbol(token) for token in tokens])
528
529
class Dummy(Symbol):
    """
    A variation of Symbol in which every instance is unique, identified by an
    internal count index. When no name is supplied, the name 'Dummy_<index>'
    is used, where <index> is the current value of the count index. This is
    useful when a unique, temporary variable is needed and the name of the
    variable used in the expression is not important.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # Class-wide counter: each new Dummy takes its current value as index.
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol whose name is the given string, stripped
        of surrounding whitespace. Raise TypeError if name is neither None nor
        a string.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # Both fields feed sortkey(), hence the hash used just below.
        self._index = index
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count = index + 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return (name, index), so that dummies sharing a name still differ.
        """
        name, index = self._name, self._index
        return name, index

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        name, index = self.name, self._index
        return '$${}_{{{}}}$$'.format(name, index)
581
582
class Rational(LinExpr, Fraction):
    """
    A particular case of linear expressions are rational values, i.e. linear
    expressions consisting only of a constant term, with no symbol. They are
    implemented by the Rational class, that inherits from both LinExpr and
    fractions.Fraction classes.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Return the rational value built from the same arguments that
        fractions.Fraction accepts.
        """
        self = object.__new__(cls)
        value = Fraction(numerator, denominator)
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # Fraction keeps its value in these two attributes; the inherited
        # Fraction methods (numerator, denominator, hashing) read them.
        self._numerator, self._denominator = value.numerator, value.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """
        The constant term, which for a Rational is the instance itself.
        """
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        if self.denominator != 1:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)
        return '{!r}'.format(self.numerator)

    def _repr_latex_(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator == 1:
            return '$${}$$'.format(numerator)
        if numerator < 0:
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-numerator,
                                                    denominator)
        return '$$\\frac{{{}}}{{{}}}$$'.format(numerator,
                                               denominator)
627 return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
628 self.denominator)