82d75d00229b5646f1fe3a46299e58c4a90360cb
[linpy.git] / linpy / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of LinPy.
4 #
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
17
18 import ast
19 import functools
20 import numbers
21 import re
22
23 from collections import OrderedDict, defaultdict, Mapping
24 from fractions import Fraction, gcd
25
26
27 __all__ = [
28 'LinExpr',
29 'Symbol', 'Dummy', 'symbols',
30 'Rational',
31 ]
32
33
34 def _polymorphic(func):
35 @functools.wraps(func)
36 def wrapper(left, right):
37 if isinstance(right, LinExpr):
38 return func(left, right)
39 elif isinstance(right, numbers.Rational):
40 right = Rational(right)
41 return func(left, right)
42 return NotImplemented
43 return wrapper
44
45
class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients and
        the constant term must be rational numbers.

        For example, the linear expression x + 2y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        x + 2*y + 1
        >>> LinExpr([(x, 1), (y, 2)], 1)
        x + 2*y + 1

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1
        x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2*y + 1')
        x + 2*y + 1

        A linear expression with a single symbol of coefficient 1 and no
        constant term is automatically subclassed as a Symbol instance. A linear
        expression with no symbol, only a constant term, is automatically
        subclassed as a Rational instance. Terms with a null coefficient are
        discarded before this classification takes place, so that x - x is the
        Rational 0 and x + y - y is the Symbol x.

        Raise TypeError if a key is not a Symbol instance, if a coefficient or
        the constant is not a rational number, or if a constant is passed
        along with a string.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Null terms must be dropped *before* choosing the result class,
        # otherwise expressions such as x - x or x + y - y would be neither
        # Rational nor Symbol instances although they reduce to one.
        coefficients = [(symbol, Fraction(coefficient))
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        # keep the terms in a canonical order, so that equal expressions
        # have equal ordered dictionaries, hashes and representations
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if the argument is
        not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return Rational(self._coefficients.get(symbol, 0))

    # expression[symbol] is a shorthand for expression.coefficient(symbol)
    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return Rational(self._constant)

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        """
        Return a hash computed from the ordered linear terms and the constant
        term.
        """
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        """
        Return True: an expression built by this class always has at least one
        non-null linear term (Rational overrides this method).
        """
        return True

    def __pos__(self):
        """
        Return the expression itself.
        """
        return self

    def __neg__(self):
        """
        Return the opposite of the expression.
        """
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        """
        Return the difference other - self, for a left operand that is not
        itself a linear expression.
        """
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational. Any other
        operand yields NotImplemented.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational. Any other
        operand yields NotImplemented.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal, i.e. whether they have
        the same linear terms and the same constant term.
        """
        return isinstance(other, LinExpr) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    # The ordering operators do not return booleans: they build the objects
    # returned by the Le, Lt, Ge and Gt callables of the polyhedra module,
    # which is imported lazily in each method.

    def __le__(self, other):
        """
        Return Le(self, other), as defined in the polyhedra module.
        """
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        """
        Return Lt(self, other), as defined in the polyhedra module.
        """
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        """
        Return Ge(self, other), as defined in the polyhedra module.
        """
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        """
        Return Gt(self, other), as defined in the polyhedra module.
        """
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by its lowest common denominator to
        make all values integer.
        """
        # values() always yields at least the constant term, hence the
        # reduction never runs on an empty sequence
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs. The substitutions are
        simultaneous: a replacement expression is never substituted again,
        which makes it possible to swap symbols.

        >>> e.subs({x: y, y: x})
        2*x + y + 1

        If a symbol appears several times in a sequence of pairs, only its
        first replacement is used. Symbols that do not appear in the expression
        are ignored.
        """
        if expression is None:
            if isinstance(symbol, Symbol):
                raise TypeError('expression must be provided')
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            pairs = list(symbol)
        else:
            pairs = [(symbol, expression)]
        replacements = {}
        for old, new in pairs:
            if not isinstance(old, Symbol):
                raise TypeError('symbols must be Symbol instances')
            replacements.setdefault(old, new)
        # Rebuild the expression term by term from the *original* terms, so
        # that the outcome does not depend on the order of the substitutions.
        result = Rational(self._constant)
        for old, coefficient in self._coefficients.items():
            result = result + coefficient * replacements.get(old, old)
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Convert a node of a Python abstract syntax tree into a linear
        expression. Names become symbols and numeric literals become rational
        constants; unary minus and the binary operators +, -, * and / are
        evaluated with the overloaded operators. Raise SyntaxError for any
        other construct.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        if isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        if isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        if isinstance(node, ast.Name):
            return Symbol(node.id)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        if isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
            raise SyntaxError('invalid syntax')
        # Numeric literals are ast.Num nodes on old interpreters and
        # ast.Constant nodes on recent ones, where ast.Num is deprecated or
        # missing. The node class is matched by name so that the deprecated
        # attribute of the ast module is never accessed.
        nodetype = type(node).__name__
        if nodetype == 'Constant':
            value = node.value
        elif nodetype == 'Num':
            value = node.n
        else:
            value = None
        # ast.Constant also holds strings, booleans, None...: only keep numbers
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return Rational(value)
        raise SyntaxError('invalid syntax')

    # a number or a closing parenthesis directly followed by an identifier or
    # an opening parenthesis, i.e. an implicit multiplication
    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        # The second positional argument of ast.parse is the file name, not
        # the mode. The string is parsed in the default module mode, which is
        # what the _fromast() implementations unwrap.
        tree = ast.parse(string, filename='<string>')
        return cls._fromast(tree)

    def __repr__(self):
        """
        Return a human-readable representation of the expression, such as
        '2*x - y + 1'.
        """
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        """
        Return a LaTeX representation of the expression, enclosed in '$$'
        delimiters.
        """
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}'.format(coefficient._repr_latex_().strip('$'))
            elif coefficient > 0:
                string += ' + {}'.format(coefficient._repr_latex_().strip('$'))
            elif coefficient < 0:
                string += ' - {}'.format((-coefficient)._repr_latex_().strip('$'))
            string += '{}'.format(symbol._repr_latex_().strip('$'))
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant._repr_latex_().strip('$'))
        elif constant > 0:
            string += ' + {}'.format(constant._repr_latex_().strip('$'))
        elif constant < 0:
            string += ' - {}'.format((-constant)._repr_latex_().strip('$'))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return the string form of the expression enclosed in parentheses.
        Constants and symbols are left bare, unless always is True.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a linear expression from a sympy expression. Raise ValueError if
        the sympy expression is not linear.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            # reads the p and q attributes of the sympy coefficient as the
            # numerator and the denominator
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return LinExpr(coefficients, constant)

    def tosympy(self):
        """
        Convert the linear expression to a sympy expression.
        """
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr
439
440
class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument. Leading and
        trailing whitespace is stripped from the name. Raise TypeError if the
        name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # the name must be set first: the symbol is used as a dictionary key
        # just below, and its hash is computed from the name
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        """
        Return a hash computed from the sorting key, consistently with the
        equality test.
        """
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        """
        Return True, as the expression is a symbol.
        """
        return True

    def __eq__(self, other):
        """
        Test whether two symbols are equal, i.e. whether they have the same
        sorting key. Return NotImplemented when the other operand is not a
        symbol, so that the comparison is delegated to the other operand
        instead of failing on the missing sortkey() method.
        """
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Convert a node of a Python abstract syntax tree into a symbol. Only a
        single name is accepted, possibly wrapped in a module or an expression
        statement. Raise SyntaxError for any other construct.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        """
        Return the name of the symbol.
        """
        return self.name

    def _repr_latex_(self):
        """
        Return the name of the symbol, enclosed in '$$' delimiters.
        """
        return '$${}$$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy symbol into a symbol with the same name. A sympy.Dummy
        is converted into a Dummy instance. Raise TypeError if the argument is
        not a sympy symbol.
        """
        import sympy
        # sympy.Dummy is tested before sympy.Symbol, as in the original order
        # of the checks
        if isinstance(expr, sympy.Dummy):
            return Dummy(expr.name)
        elif isinstance(expr, sympy.Symbol):
            return Symbol(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
521
522
class Dummy(Symbol):
    """
    A variation of Symbol in which every instance is unique, being identified
    by an internal count index in addition to its name. When no name is
    supplied, one is derived from the count index. Dummy symbols are meant for
    unique, temporary variables whose name does not matter.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # number of dummy symbols created so far, also the index of the next one
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol with the name string given in argument.
        Raise TypeError if a name is given and is not a string.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        dummy = object.__new__(cls)
        # the name and the index form the sorting key, hence the hash: both
        # must be set before the symbol is used as a dictionary key
        dummy._name = name.strip()
        dummy._index = index
        dummy._symbols = (dummy,)
        dummy._dimension = 1
        dummy._constant = Fraction(0)
        dummy._coefficients = {dummy: Fraction(1)}
        Dummy._count = index + 1
        return dummy

    def __hash__(self):
        """
        Return a hash computed from the sorting key.
        """
        key = self.sortkey()
        return hash(key)

    def sortkey(self):
        """
        Return the pair (name, index) as the sorting key of the symbol.
        """
        return (self._name, self._index)

    def __repr__(self):
        """
        Return the name of the symbol prefixed by an underscore.
        """
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        """
        Return the name of the symbol with its index as a subscript, enclosed
        in '$$' delimiters.
        """
        return '$${}_{{{}}}$$'.format(self.name, self._index)
574
575
def symbols(names):
    """
    Return a tuple of symbols, one for each given name. The names are either a
    sequence of strings, or a single string in which they are separated by
    commas or whitespace. This is a convenient way to define several symbols
    at once.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    if isinstance(names, str):
        # commas are turned into blanks, so that a single split() handles
        # both kinds of separators
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
589
590
class Rational(LinExpr, Fraction):
    """
    Rational values are the linear expressions that have no symbol and only
    consist of a constant term. The Rational class implements them, and
    inherits from both the LinExpr and fractions.Fraction classes.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Return the rational value built from the given arguments, which are
        passed as they are to fractions.Fraction.
        """
        value = Fraction(numerator, denominator)
        rational = object.__new__(cls)
        # state read by the LinExpr methods
        rational._coefficients = {}
        rational._constant = value
        rational._symbols = ()
        rational._dimension = 0
        # state read by the Fraction methods
        rational._numerator = value.numerator
        rational._denominator = value.denominator
        return rational

    def __hash__(self):
        """
        Return the hash computed by fractions.Fraction.
        """
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """
        The constant term of the expression, i.e. the rational value itself.
        """
        return self

    def isconstant(self):
        """
        Return True, as the expression is a constant.
        """
        return True

    def __bool__(self):
        """
        Return the truth value computed by fractions.Fraction.
        """
        return Fraction.__bool__(self)

    def __repr__(self):
        """
        Return the value as 'numerator/denominator', or as the numerator alone
        when the denominator is 1.
        """
        top, bottom = self.numerator, self.denominator
        if bottom != 1:
            return '{!r}/{!r}'.format(top, bottom)
        return '{!r}'.format(top)

    def _repr_latex_(self):
        """
        Return a LaTeX representation of the value, enclosed in '$$'
        delimiters. A fraction is written with \\frac, the minus sign being
        placed in front of it.
        """
        top, bottom = self.numerator, self.denominator
        if bottom == 1:
            return '$${}$$'.format(top)
        if top < 0:
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-top, bottom)
        return '$$\\frac{{{}}}{{{}}}$$'.format(top, bottom)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy rational, or any rational number, into a Rational
        instance. Raise TypeError for any other argument.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')