Rename Expression class into LinExpr
[linpy.git] / linpy / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of LinPy.
4 #
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
17
import ast
import functools
import numbers
import re

try:
    from collections import OrderedDict, defaultdict, Mapping
except ImportError:  # Python >= 3.10: Mapping is only in collections.abc
    from collections import OrderedDict, defaultdict
    from collections.abc import Mapping
try:
    from fractions import Fraction, gcd
except ImportError:  # Python >= 3.9: fractions.gcd was removed
    from fractions import Fraction
    from math import gcd
25
26
# Names exported by "from linexprs import *".
__all__ = [
    'LinExpr',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
32
33
34 def _polymorphic(func):
35 @functools.wraps(func)
36 def wrapper(left, right):
37 if isinstance(right, LinExpr):
38 return func(left, right)
39 elif isinstance(right, numbers.Rational):
40 right = Rational(right)
41 return func(left, right)
42 return NotImplemented
43 return wrapper
44
45
class LinExpr:
    """
    This class implements linear expressions: a sum of rational multiples of
    symbols plus a rational constant.

    The constructor normalizes its result: an expression without any
    (nonzero) term is returned as a Rational, and the expression 1*symbol + 0
    is returned as the Symbol itself.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a new expression.

        coefficients is either a string (parsed with fromstring, in which case
        constant must be left to 0), a mapping from Symbol to rational
        coefficient, or an iterable of (symbol, coefficient) pairs. constant
        is the rational constant term.

        Terms with a zero coefficient are dropped. If no term remains, a
        Rational is returned. If the only remaining term is 1*symbol and the
        constant is 0, that Symbol is returned.

        Raises TypeError on invalid argument types.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Drop null terms *before* normalizing. Otherwise x - x would not be
        # a Rational, and {x: 1, y: 0} would compare equal to the Symbol x
        # while having a different type and hash value.
        terms = [(symbol, Fraction(coefficient))
                 for symbol, coefficient in coefficients if coefficient != 0]
        # Stable sort: if a symbol is repeated, its last pair still wins.
        terms.sort(key=lambda item: item[0].sortkey())
        terms = OrderedDict(terms)
        if not terms:
            return Rational(constant)
        if len(terms) == 1 and constant == 0:
            symbol, coefficient = next(iter(terms.items()))
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        self._coefficients = terms
        self._constant = Fraction(constant)
        self._symbols = tuple(terms)
        self._dimension = len(terms)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, as a Rational (0 if
        the symbol does not appear in the expression).

        Raises TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return Rational(self._coefficients.get(symbol, 0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the (symbol, coefficient) pairs of the expression, in
        symbol order, with coefficients as Rational instances.
        """
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """
        Return the constant value of an expression, as a Rational.
        """
        return Rational(self._constant)

    @property
    def symbols(self):
        """
        Return the tuple of symbols of the expression, in symbol order.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        Return the number of symbols of the expression.
        """
        return self._dimension

    def __hash__(self):
        # Equal expressions have their terms in the same (sorted) order.
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return true if an expression is a constant.
        """
        return False

    def issymbol(self):
        """
        Return true if an expression is a symbol.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values of the expression (in symbol
        order) followed by its constant value, all as Rational instances.
        """
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        # A normalized non-Rational expression has at least one nonzero term,
        # so it is never identically zero (Rational overrides this method).
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the expression by a rational number.

        Any other operand gives NotImplemented: a product of two non-constant
        expressions is not linear.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the expression by a rational number.

        Any other operand gives NotImplemented.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two expressions are equal (same terms, same constant).
        """
        # _polymorphic has already converted rational numbers to Rational and
        # answered NotImplemented for any other non-LinExpr operand.
        return (self._coefficients == other._coefficients
                and self._constant == other._constant)

    # The comparison operators build constraints rather than booleans. They
    # delegate to the polyhedra module, imported lazily (presumably to avoid
    # a circular import).

    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Multiply an expression by the least common multiple of the
        denominators of its coefficients and constant, making all of them
        integer values.
        """
        # fractions.gcd was removed in Python 3.9; math.gcd is equivalent here
        # since denominators are positive.
        import math
        denominators = [value.denominator for value in self.values()]
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
                               denominators)
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute symbols with expressions and return the resulting
        expression.

        Call either subs(symbol, expression), or subs(substitutions) with a
        mapping or an iterable of (symbol, expression) pairs. Substitutions
        are applied one after the other, in iteration order.

        Raises TypeError if a substituted key is not a Symbol instance.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result._coefficients.items()
                if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = LinExpr(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Accepted syntax: names, int/float literals, unary + and -, and the
        binary operators + - * /. Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, getattr(ast, 'Constant', ())):
            # Python >= 3.8 represents numeric literals as ast.Constant
            # (ast.Num is deprecated and removed in Python 3.14).
            value = node.value
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return Rational(value)
        elif type(node).__name__ == 'Num':
            # Numeric literal on Python < 3.8.
            return Rational(node.n)
        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return -cls._fromast(node.operand)
            elif isinstance(node.op, ast.UAdd):
                return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string, e.g. '2x + 3y - 1'.

        Raises SyntaxError if the string is not a linear expression.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        # Only an AST is built, nothing is evaluated. The second positional
        # parameter of ast.parse is the filename, so the mode is passed by
        # keyword: 'exec' yields a Module wrapping an Expr, which _fromast
        # unwraps.
        tree = ast.parse(string, mode='exec')
        return cls._fromast(tree)

    def __repr__(self):
        # Signs are tested on the raw Fraction values. Comparing Rational
        # wrappers would dispatch to LinExpr.__gt__/__lt__, which build
        # polyhedra instead of returning booleans.
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self._constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        def latex(value):
            # LaTeX source of a Rational or Symbol, without $$ delimiters.
            return value._repr_latex_().strip('$')
        # As in __repr__, signs are tested on raw Fraction values.
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += latex(Rational(coefficient))
            elif coefficient > 0:
                string += ' + ' + latex(Rational(coefficient))
            else:
                string += ' - ' + latex(Rational(-coefficient))
            string += latex(symbol)
        constant = self._constant
        if len(string) == 0:
            string += latex(Rational(constant))
        elif constant > 0:
            string += ' + ' + latex(Rational(constant))
        elif constant < 0:
            string += ' - ' + latex(Rational(-constant))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol (parentheses are always added if always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a linear sympy expression to an expression.

        Raises ValueError if expr contains a non-linear term.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return LinExpr(coefficients, constant)

    def tosympy(self):
        """
        Return the expression as a sympy object, with every symbol converted
        to a sympy.Symbol of the same name.
        """
        import sympy
        # Start from a sympy zero and convert numbers explicitly, so that a
        # constant expression also yields a sympy object rather than a
        # Rational of this module.
        expr = sympy.Integer(0)
        for symbol, coefficient in self._coefficients.items():
            factor = sympy.Rational(coefficient.numerator,
                                    coefficient.denominator)
            expr += factor * sympy.Symbol(symbol.name)
        expr += sympy.Rational(self._constant.numerator,
                               self._constant.denominator)
        return expr
384
385
class Symbol(LinExpr):
    """
    Symbol of linear expressions, identified by its name.

    A symbol is itself the linear expression 1*symbol + 0.
    """

    def __new__(cls, name):
        """
        Create and return a symbol from a string (surrounding whitespace is
        stripped).

        Raises TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._name = name.strip()
        # _name must be set first: hashing self (for this dict) uses it.
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        Return the name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return the key used to order symbols in expressions, and to test
        symbols for equality.
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        """
        Test whether two expressions are equal.

        Two symbols are equal when their sort keys are equal. Any other
        operand (number or non-symbol expression) is compared with the
        generic LinExpr equality.
        """
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        # Non-symbols have no sortkey(); fall back to the term-by-term
        # comparison, which also accepts rational numbers.
        return super().__eq__(other)

    def asdummy(self):
        """
        Return a new Dummy symbol with the same name.
        """
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Build a symbol from a Python AST node; raise SyntaxError if the node
        is not a single name.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Dummy to a Dummy and a sympy.Symbol to a Symbol.

        Raises TypeError for any other object.
        """
        import sympy
        if isinstance(expr, sympy.Dummy):
            return Dummy(expr.name)
        elif isinstance(expr, sympy.Symbol):
            return Symbol(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
449
450
class Dummy(Symbol):
    """
    Symbol distinct from every other symbol, including other dummies with
    the same name, so that no variable is repeated by accident.

    Each instance gets a unique index from a class-wide counter. The index is
    part of its sort key, hence of its equality test and hash.
    """

    # Number of Dummy instances created so far, i.e. the next index to assign.
    _count = 0

    def __new__(cls, name=None):
        """
        Create and return a new dummy symbol.

        When name is omitted, the name 'Dummy_<index>' is generated.
        Raises TypeError if name is given and is not a string.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # _index and _name must be set first: hashing self (for the
        # coefficients dict below) uses both through sortkey().
        self._index = Dummy._count
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        return self._name, self._index

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        # Brace the name: generated names such as 'Dummy_0' already contain a
        # subscript, and 'Dummy_0_{0}' is a LaTeX "double subscript" error.
        return '$${{{}}}_{{{}}}$$'.format(self.name, self._index)
486
487
def symbols(names):
    """
    Return a tuple of Symbol instances built from the given names.

    names is either a single string in which names are separated by commas
    and/or whitespace (e.g. 'x, y z'), or an iterable of name strings.
    """
    namelist = names.replace(',', ' ').split() if isinstance(names, str) else names
    return tuple(Symbol(name) for name in namelist)
495
496
class Rational(LinExpr, Fraction):
    """
    Rational constant: an integer or fraction of any size. It is both a
    linear expression without any symbol and a fractions.Fraction.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a rational from the same arguments as fractions.Fraction.
        """
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        # LinExpr view: no term, only a constant.
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # Fraction view: the attributes read by the inherited Fraction
        # methods (numerator, denominator, hash, ...).
        self._numerator = value.numerator
        self._denominator = value.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """
        Return the rational itself, which is its own constant term.
        """
        return self

    def isconstant(self):
        """
        Return True: a rational is always a constant.
        """
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator != 1:
            return '{!r}/{!r}'.format(numerator, denominator)
        return '{!r}'.format(numerator)

    def _repr_latex_(self):
        top, bottom = self.numerator, self.denominator
        if bottom == 1:
            return '$${}$$'.format(top)
        if top < 0:
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-top, bottom)
        return '$$\\frac{{{}}}{{{}}}$$'.format(top, bottom)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a rational from a sympy.Rational (or any numbers.Rational).

        Raises TypeError for any other object.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')
557 else:
558 raise TypeError('expr must be a sympy.Rational instance')