bd3ad5a61b125903ca9006ee88cdce8e9aa0a8e4
[linpy.git] / pypol / linexprs.py
1 # Copyright 2014 MINES ParisTech
2 #
3 # This file is part of Linpy.
4 #
5 # Linpy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # Linpy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with Linpy. If not, see <http://www.gnu.org/licenses/>.
17
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from fractions import Fraction

try:  # the ABC lives in collections.abc; the collections alias was removed in Python 3.10
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping

try:  # math.gcd exists since Python 3.5; fractions.gcd was removed in Python 3.9
    from math import gcd
except ImportError:
    from fractions import gcd
25
26
# Names exported by a star-import of this module.
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
32
33
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is always an
    Expression.

    Expression operands are passed through unchanged, other
    numbers.Rational values are wrapped in Rational first, and any other
    operand makes the wrapper return NotImplemented so that Python can try
    the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
44
45
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of ``coefficient * symbol`` terms plus a constant.
    Coefficients and the constant are stored as fractions.Fraction values.
    Terms are kept sorted by Symbol.sortkey(), and terms with a zero
    coefficient are dropped. No method of this class mutates an instance
    after __new__ returns it.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients may be:
          - a string, which is parsed with fromstring();
          - None, in which case the expression is the constant alone;
          - a mapping from Symbol to rational coefficient;
          - an iterable of (Symbol, coefficient) pairs.

        The result is simplified. An expression without any non-zero term is
        returned as a Rational, and ``1*symbol + 0`` is returned as the
        Symbol itself.

        Raises TypeError if a string is combined with a non-zero constant,
        if a key is not a Symbol, or if a coefficient or the constant is not
        a rational number.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Drop zero terms *before* simplifying, so that e.g. ``x - x``
        # collapses to Rational(0) and [(x, 1), (y, 0)] collapses to x.
        coefficients = [(symbol, Fraction(coefficient))
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol as a Rational, or 0 if the symbol
        does not appear in the expression.

        Raises TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return Rational(self._coefficients.get(symbol, 0))

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs, coefficients as Rational."""
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """The constant term, as a Rational."""
        return Rational(self._constant)

    @property
    def symbols(self):
        """Tuple of the symbols that have a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """Return False; the Rational subclass overrides this."""
        return False

    def issymbol(self):
        """Return False; the Symbol subclass overrides this."""
        return False

    def values(self):
        """Yield every coefficient, then the constant, all as Rational."""
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        # A general (non-Rational) expression is always truthy.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        # Only scaling by a rational number keeps the expression linear.
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return Expression(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        # Only division by a rational number keeps the expression linear.
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return Expression(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        # returns a boolean, not a constraint
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    def __le__(self, other):
        # Builds a constraint object via polyhedra.Le, not a bool.
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that every value
        of the result is an integer.
        """
        # values() always yields at least the constant, so reduce() never
        # receives an empty sequence.
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions for symbols.

        Called as ``subs(symbol, expression)`` for a single substitution, or
        as ``subs(substitutions)`` where substitutions is a mapping or an
        iterable of (symbol, expression) pairs. Substitutions are applied
        one after the other, in iteration order.

        Raises TypeError if a key is not a Symbol.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result._coefficients.items()
                if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @staticmethod
    def _astnumber(node):
        """
        Return the numeric value of a literal AST node, or None.

        Python >= 3.8 parses numbers as ast.Constant. Older versions use
        ast.Num, which was deprecated in 3.8 and later removed, so both
        types are looked up defensively. Booleans are not numbers here.
        """
        constant_type = getattr(ast, 'Constant', None)
        if constant_type is not None and isinstance(node, constant_type):
            value = node.value
            if isinstance(value, numbers.Number) and \
                    not isinstance(value, bool):
                return value
            return None
        num_type = getattr(ast, 'Num', None)
        if num_type is not None and isinstance(node, num_type):
            return node.n
        return None

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a parsed Python AST.

        Supported nodes are names, numeric literals, unary + and -, and the
        binary operators + - * /. Anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # tree produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        else:
            value = cls._astnumber(node)
            if value is not None:
                return Rational(value)
        raise SyntaxError('invalid syntax')

    # Finds a number (or ')') directly followed by a name (or '('), such as
    # '5x', '2 (y)' or ')x', so that fromstring() can insert the implicit
    # '*'. The look-behind keeps digits that belong to an identifier from
    # matching, so 'x1y' stays one name instead of becoming 'x1*y'.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression written with Python syntax, e.g. '2x + 1'.

        Raises SyntaxError on unsupported syntax.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # NOTE: the second positional argument of ast.parse is the filename
        # (used only in error messages). The mode stays 'exec', so the tree
        # is an ast.Module.
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        """Return a LaTeX rendering wrapped in '$$' (used by IPython)."""
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}'.format(coefficient._repr_latex_().strip('$'))
            elif coefficient > 0:
                string += ' + {}'.format(coefficient._repr_latex_().strip('$'))
            elif coefficient < 0:
                string += ' - {}'.format((-coefficient)._repr_latex_().strip('$'))
            string += '{}'.format(symbol._repr_latex_().strip('$'))
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant._repr_latex_().strip('$'))
        elif constant > 0:
            string += ' + {}'.format(constant._repr_latex_().strip('$'))
        elif constant < 0:
            string += ' - {}'.format((-constant)._repr_latex_().strip('$'))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol (or always=True forces the parentheses).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build an Expression from a linear sympy expression.

        Each term must be the constant (keyed by sympy.S.One) or a
        sympy.Symbol. Anything else raises ValueError. Coefficients are read
        through .p/.q, so they are assumed to be sympy Rationals. Sympy
        symbols become plain Symbol instances with the same name.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """
        Convert to a sympy expression; each symbol is mapped to
        sympy.Symbol(name).
        """
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr
331
332
class Symbol(Expression):
    """
    A named variable.

    As an expression, a symbol is ``1*symbol + 0``. Two symbols are equal,
    and hash alike, when their sortkey() values match. For plain symbols
    that means they have the same name.
    """

    def __new__(cls, name):
        """
        Create a symbol named name (surrounding whitespace is stripped).

        Raises TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # _name must be set before building the dict below: using self as a
        # key hashes it, and __hash__ reads the name through sortkey().
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Key used to order symbols within expressions and to compare them."""
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        """
        Compare with another symbol by sortkey(). Any other operand is
        compared as a linear expression, so ``x == 1`` returns False instead
        of raising AttributeError.
        """
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return super().__eq__(other)

    def asdummy(self):
        """Return a new Dummy with the same name."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Build a Symbol from an AST that holds a single bare name.

        Raises SyntaxError for any other AST.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy Dummy to a Dummy, or a sympy Symbol to a Symbol,
        keeping the name.

        Raises TypeError for anything else.
        """
        import sympy
        # Dummy is tested first because sympy.Dummy subclasses sympy.Symbol.
        if isinstance(expr, sympy.Dummy):
            return Dummy(expr.name)
        elif isinstance(expr, sympy.Symbol):
            return Symbol(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
390
391
class Dummy(Symbol):
    """
    A symbol that differs from every other symbol, even one with the same
    name. Its sortkey() includes a per-instance index taken from a
    class-wide counter.
    """

    # Class-wide count of Dummy instances created so far. The next instance
    # receives this value as its index.
    _count = 0

    def __new__(cls, name=None):
        """
        Create a dummy named name, or 'Dummy_<index>' when name is None.

        Raises TypeError if name is neither None nor a string.
        """
        if name is not None and not isinstance(name, str):
            raise TypeError('name must be a string')
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        self = object.__new__(cls)
        self._index = index
        self._name = name.strip()
        # Both parts of sortkey() are now set, so self can be hashed as a key.
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Return (name, index), which is unique per Dummy instance."""
        return (self._name, self._index)

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)
422
423
def symbols(names):
    """
    Return a tuple of Symbol instances.

    names is either a string of names separated by commas and/or
    whitespace, or an iterable of name strings.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
428
429
class Rational(Expression, Fraction):
    """
    A constant linear expression that is also a fractions.Fraction.

    Expression precedes Fraction in the bases, so the operators Expression
    defines (+, -, *, /, ==, comparisons) resolve to Expression's versions.
    Hashing and truthiness are explicitly delegated to Fraction.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Create the constant Fraction(numerator, denominator).

        Fraction's constructor is bypassed. The private attributes that
        Fraction stores its value in are filled from the normalized value.
        """
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        self._numerator = value.numerator
        self._denominator = value.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """A constant is its own constant term."""
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        if self.denominator != 1:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)
        return '{!r}'.format(self.numerator)

    def _repr_latex_(self):
        """Return an integer, or a \\frac with any sign pulled in front."""
        if self.denominator == 1:
            return '$${}$$'.format(self.numerator)
        if self.numerator >= 0:
            return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
                self.denominator)
        return '$$-\\frac{{{}}}{{{}}}$$'.format(-self.numerator,
            self.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy Rational (or any numbers.Rational) to a Rational.

        Raises TypeError for anything else.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')