Fix error message
[linpy.git] / linpy / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of LinPy.
4 #
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
17
import ast
import functools
import numbers
import re

from collections import defaultdict, OrderedDict
from collections.abc import Mapping
from fractions import Fraction
from math import gcd
25
26
# Public API of this module.
__all__ = ['Dummy', 'LinExpr', 'Rational', 'Symbol', 'symbols']
34
35
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is always a LinExpr.

    Rational numbers (int, Fraction, ...) are wrapped into Rational instances
    before calling func; any other operand type makes the wrapper return
    NotImplemented so that Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
46
47
class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients
        and the constant term must be rational numbers.

        For example, the linear expression x + 2*y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        x + 2*y + 1
        >>> LinExpr([(x, 1), (y, 2)], 1)
        x + 2*y + 1

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1
        x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2y + 1')
        x + 2*y + 1

        Terms with a null coefficient are discarded; if a symbol is given
        several times in a sequence, its last coefficient is kept. After this
        normalization, a linear expression with a single symbol of
        coefficient 1 and no constant term is returned as a Symbol instance,
        and a linear expression with no symbol, only a constant term, is
        returned as a Rational instance.

        Raise TypeError if a symbol is not a Symbol instance, if a coefficient
        or the constant term is not a rational number, or if a string is given
        together with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            coefficients = ()
        elif isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        terms = {}
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
            terms[symbol] = coefficient
        # The constant is validated even when no coefficients are given, so
        # that LinExpr(None, c) and LinExpr([], c) behave the same way.
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        constant = Fraction(constant)
        # Null terms must be dropped *before* the special cases below:
        # otherwise x - x would not be a Rational and x + y - y would be a
        # plain LinExpr equal to the Symbol x but with a different hash.
        terms = [(symbol, Fraction(coefficient))
                 for symbol, coefficient in terms.items()
                 if coefficient != 0]
        if not terms:
            return Rational(constant)
        if len(terms) == 1 and constant == 0:
            symbol, coefficient = terms[0]
            if coefficient == 1:
                return symbol
        # Sorting gives equal expressions the same term order, which
        # __eq__ (OrderedDict comparison) and __hash__ rely on.
        terms.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(terms)
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if symbol is not a
        Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression, in the order of Symbol.sortkey(). The constant term is
        ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term (always yielded last).
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        # Non-constant expressions are always truthy; Rational overrides this
        # with the truth value of the number.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational. Products of
        two non-constant expressions are not linear and are not supported.
        """
        if isinstance(other, numbers.Rational):
            coefficients = (
                (symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational. Dividing
        by zero raises ZeroDivisionError.
        """
        if isinstance(other, numbers.Rational):
            coefficients = (
                (symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal. Unlike methods
        LinExpr.__lt__(), LinExpr.__le__(), LinExpr.__ge__(), LinExpr.__gt__(),
        the result is a boolean value, not a polyhedron. To express that two
        linear expressions are equal or not equal, use functions Eq() and Ne()
        instead.
        """
        return self._coefficients == other._coefficients and \
            self._constant == other._constant

    # The comparison operators build Polyhedron objects rather than booleans.
    # Polyhedron is imported locally, presumably to avoid a circular import
    # with the polyhedra module. NOTE(review): the second argument of
    # Polyhedron presumably lists inequalities "expr >= 0"; the "- 1" in the
    # strict comparisons then matches integer points only — confirm there.

    @_polymorphic
    def __lt__(self, other):
        """
        Return the polyhedron built from the expression other - self - 1.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self - 1])

    @_polymorphic
    def __le__(self, other):
        """
        Return the polyhedron built from the expression other - self.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self])

    @_polymorphic
    def __ge__(self, other):
        """
        Return the polyhedron built from the expression self - other.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other])

    @_polymorphic
    def __gt__(self, other):
        """
        Return the polyhedron built from the expression self - other - 1.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other - 1])

    def scaleint(self):
        """
        Return the expression multiplied by its lowest common denominator
        (the least common multiple of the denominators of all its values) to
        make all values integer.
        """
        lcd = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcd

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear,
        or if a substituted key is not a Symbol instance.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs.

        >>> e.subs({x: y, y: x})
        2*x + y + 1
        """
        if expression is None:
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for symbol in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        # All substitutions are applied simultaneously, from the original
        # coefficients, so {x: y, y: x} swaps x and y.
        result = Rational(self._constant)
        for symbol, coefficient in self._coefficients.items():
            expression = substitutions.get(symbol, symbol)
            result += coefficient * expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node. Accept trees produced by
        ast.parse() in both 'exec' mode (a Module with a single expression
        statement) and 'eval' mode (an Expression). Raise SyntaxError for
        unsupported constructs.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Constant):
            value = node.value
            # ast.Num (deprecated, then removed) used to match int and float
            # literals but not booleans; accept exactly the same literals.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return Rational(value)
        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return -cls._fromast(node.operand)
            if isinstance(node.op, ast.UAdd):
                return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # Matches a number (or a closing parenthesis) followed by an identifier
    # (or an opening parenthesis), to insert an implicit multiplication. The
    # lookbehind prevents matching digits inside an identifier such as 'x1y'.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted, or if the resulting expression is not an
        instance of cls.

        >>> LinExpr.fromstring('2x - 1')
        2*x - 1
        """
        # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        # The second positional parameter of ast.parse() is the file name,
        # not the mode; _fromast() only accepts a single expression statement.
        tree = ast.parse(string)
        expression = cls._fromast(tree)
        if not isinstance(expression, cls):
            raise SyntaxError('invalid syntax')
        return expression

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol (or always is True, which forces parentheses).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expression):
        """
        Create a linear expression from a SymPy expression. Raise TypeError if
        the SymPy expression is not linear, if one of its coefficients is not
        a rational number, if it contains SymPy Dummy symbols, or if the
        result cannot be converted to an instance of cls.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expression.as_coefficients_dict().items():
            # Floating-point coefficients have no exact p/q representation.
            if not coefficient.is_Rational:
                raise TypeError('non-rational coefficient: {!r}'.format(
                    coefficient))
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # We cannot properly convert dummy symbols with respect to
                # symbol equalities.
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(
                    expression))
        expression = LinExpr(coefficients, constant)
        if not isinstance(expression, cls):
            raise TypeError('cannot convert to a {} instance'.format(
                cls.__name__))
        return expression

    def tosympy(self):
        """
        Convert the linear expression to a SymPy expression. The result is a
        SymPy object even for a constant expression. Each symbol (including
        Dummy instances) is converted to a sympy.Symbol with the same name.
        """
        import sympy
        constant = self.constant
        # Starting from a SymPy number guarantees a SymPy result; starting
        # from 0 returned a linpy Rational for constant expressions.
        expression = sympy.Rational(constant.numerator, constant.denominator)
        for symbol, coefficient in self.coefficients():
            factor = sympy.Rational(coefficient.numerator,
                                    coefficient.denominator)
            expression += factor * sympy.Symbol(symbol.name)
        return expression
432
433
class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    __slots__ = (
        '_name',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument.

        Raise TypeError if name is not a string, and SyntaxError if name is
        not a single Python identifier (e.g. '', '2x', 'x y', 'x; y', 'True').
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        # Parsing in 'eval' mode accepts exactly one expression. In the
        # default 'exec' mode, '' raised IndexError instead of SyntaxError
        # and 'x; y' was silently truncated to 'x'.
        tree = ast.parse(name, mode='eval')
        if not isinstance(tree.body, ast.Name):
            raise SyntaxError('invalid syntax')
        self = object.__new__(cls)
        self._name = tree.body.id
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def _coefficients(self):
        # Not stored as an attribute: building {self: 1} inside __new__ would
        # hash self, and Dummy.__hash__ needs the index that Dummy.__new__
        # only assigns after Symbol.__new__ has returned.
        return {self: Fraction(1)}

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        # Symbols compare by sort key, so a Symbol never equals a Dummy.
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        return self.name
510
511
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name. Names are given either
    as a single string, where they are separated by commas and/or
    whitespace, or as a sequence of strings. This makes it easy to define
    several symbols at once.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    if isinstance(names, str):
        # Commas and whitespace both act as separators; empty items vanish.
        names = [word
                 for chunk in names.split(',')
                 for word in chunk.split()]
    return tuple(map(Symbol, names))
525
526
class Dummy(Symbol):
    """
    A variation of Symbol in which all symbols are unique and identified by
    an internal count index. If a name is not supplied then a name of the
    form 'Dummy_<index>' is generated from the count index. This is useful
    when a unique, temporary variable is needed and the name of the variable
    used in the expression is not important.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # Counter shared by every Dummy instance (it is always accessed through
    # the Dummy class, never through cls).
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol with the name string given in argument,
        or named 'Dummy_<index>' if name is None.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        self = super().__new__(cls, name)
        # The index is taken only once Symbol.__new__ has validated the name,
        # so an invalid name does not consume an index.
        self._index = Dummy._count
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return the sorting key (name, index). The index makes dummies with the
        same name distinct from each other and from the Symbol of that name.
        """
        return self._name, self._index

    def __repr__(self):
        return '_{}'.format(self.name)
568
569
class Rational(LinExpr, Fraction):
    """
    A particular case of linear expressions are rational values, i.e. linear
    expressions consisting only of a constant term, with no symbol. They are
    implemented by the Rational class, that inherits from both LinExpr and
    fractions.Fraction classes.
    """

    # Fraction's own slots ('_numerator', '_denominator') are inherited and
    # must not be declared again: redeclaring a base-class slot shadows it,
    # which the Python data model documents as undefined behavior.
    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, numerator=0, denominator=None):
        """
        Return the rational number numerator/denominator. The arguments are
        interpreted exactly as by fractions.Fraction().
        """
        self = object.__new__(cls)
        self._coefficients = {}
        self._constant = Fraction(numerator, denominator)
        self._symbols = ()
        self._dimension = 0
        # Fill in the Fraction state directly, as Fraction.__new__ is bypassed.
        self._numerator = self._constant.numerator
        self._denominator = self._constant.denominator
        return self

    def __hash__(self):
        # Hash like the equal Fraction (and int) values.
        return Fraction.__hash__(self)

    @property
    def constant(self):
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        if self.denominator == 1:
            return '{!r}'.format(self.numerator)
        else:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)