Simplify Expression.__mul__(), Expression.__truediv__()
[linpy.git] / pypol / linexprs.py
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
# Mapping moved to collections.abc (the collections alias was removed in
# Python 3.10); gcd moved to math (fractions.gcd was removed in Python 3.9).
from collections.abc import Mapping
from fractions import Fraction
from math import gcd
8
9
# Public names exported by ``from linexprs import *``.
__all__ = ['Expression', 'Symbol', 'Dummy', 'symbols', 'Rational']
15
16
17 def _polymorphic(func):
18 @functools.wraps(func)
19 def wrapper(left, right):
20 if isinstance(right, Expression):
21 return func(left, right)
22 elif isinstance(right, numbers.Rational):
23 right = Rational(right)
24 return func(left, right)
25 return NotImplemented
26 return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is stored as an ordered mapping from Symbol to Fraction
    coefficient (sorted by Symbol.sortkey()) plus a Fraction constant term.
    The constructor normalizes its result: an expression without symbols is
    returned as a Rational, and one equal to a single symbol with
    coefficient 1 and no constant is returned as that Symbol.

    Arithmetic with rational numbers returns expressions; the comparison
    operators <, <=, >, >= delegate to polyhedra.Lt/Le/Gt/Ge instead of
    returning booleans.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients is one of:
          - a string, parsed with fromstring() (constant must then be 0);
          - None, giving the constant expression Rational(constant);
          - a Mapping from Symbol to rational coefficient;
          - an iterable of (Symbol, rational coefficient) pairs.
        constant is the rational constant term.

        Zero coefficients are dropped. Returns a Rational when no symbol
        remains, the Symbol itself for a bare symbol with coefficient 1 and
        no constant, and an Expression otherwise.

        Raises TypeError for a non-Symbol key, a non-rational coefficient or
        constant, or a string combined with a nonzero constant.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        # The pairs are traversed twice (validation, then filtering). Take a
        # snapshot so that a one-shot iterator such as a generator is not
        # exhausted by the first pass, which silently dropped every term.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        coefficients = [(symbol, Fraction(coefficient))
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        constant = Fraction(constant)
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        # Allocate with object.__new__ directly: no throwaway object()
        # instance, and no re-entry into this normalizing constructor.
        self = object.__new__(cls)
        self._coefficients = OrderedDict(sorted(
            coefficients, key=lambda item: item[0].sortkey()))
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol as a Rational, or Rational(0) when
        the symbol does not occur. Raises TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        try:
            return Rational(self._coefficients[symbol])
        except KeyError:
            return Rational(0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, Rational coefficient) pairs in stored order."""
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """The constant term, as a Rational."""
        return Rational(self._constant)

    @property
    def symbols(self):
        """Tuple of the symbols with a nonzero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """Return False; Rational overrides this to return True."""
        return False

    def issymbol(self):
        """Return False; Symbol overrides this to return True."""
        return False

    def values(self):
        """Yield every coefficient, then the constant term, as Rationals."""
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        # __new__ only builds an Expression when at least one nonzero
        # coefficient remains, so it is never the zero expression
        # (Rational overrides this).
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        """Return other - self."""
        # Decorated like __sub__ so that an unsupported left operand makes
        # this return NotImplemented. The former -(self - other) form raised
        # from inside, reporting the operands in reverse order.
        return other - self

    def __mul__(self, other):
        """
        Scale the expression by a rational number.

        A product with a non-rational operand (e.g. another non-constant
        expression) is not linear, so NotImplemented is returned.
        """
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        coefficients = {symbol: coefficient * other
                        for symbol, coefficient in self._coefficients.items()}
        return Expression(coefficients, self._constant * other)

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Divide the expression by a rational number.

        Returns NotImplemented for a non-rational divisor; dividing by zero
        raises ZeroDivisionError from the Fraction arithmetic.
        """
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        coefficients = {symbol: coefficient / other
                        for symbol, coefficient in self._coefficients.items()}
        return Expression(coefficients, self._constant / other)

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that every value
        becomes an integer.
        """
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions (or numbers) for symbols.

        Call as subs(symbol, expression), or as subs(substitutions) with a
        Mapping or an iterable of (symbol, expression) pairs. Substitutions
        are applied one after another, so a later one sees the result of the
        earlier ones. Raises TypeError if a key is not a Symbol.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            coefficients = [(othersymbol, coefficient)
                            for othersymbol, coefficient in result._coefficients.items()
                            if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Recursively convert a Python AST into an expression.

        Accepts names, int/float literals (ast.Constant, which the parser
        emits since Python 3.8), unary + and -, and binary + - * /. Raises
        SyntaxError for any other construct; a non-linear product or
        quotient raises TypeError from the arithmetic operators.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Constant) and \
                isinstance(node.value, (int, float, complex)) and \
                not isinstance(node.value, bool):
            # Same literals the deprecated ast.Num check matched.
            return Rational(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # Matches a number or ')' directly followed by a name or '(' (implicit
    # multiplication). The lookbehind keeps digits inside an identifier,
    # such as the '2' in 'x2y', from being treated as a coefficient.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse string as a linear expression, e.g. '2x + 3(y - 1)'.

        Raises SyntaxError if string is not a valid expression, and
        TypeError if it is not linear.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string.strip())
        # mode must be passed by keyword: the second positional argument of
        # ast.parse is the filename, not the mode.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree.body)

    def __repr__(self):
        # Compare the raw Fraction values: Rational inherits Expression's
        # <, > operators, which build polyhedra objects instead of bools.
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self._constant
        if not string:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        """Return a LaTeX rendering of the expression wrapped in $...$."""
        def latex(value):
            # Drop the $ delimiters so the pieces can be concatenated.
            return value._repr_latex_().strip('$')

        # Signs are tested on the raw Fraction values (see __repr__).
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += latex(Rational(coefficient))
            elif coefficient > 0:
                string += ' + ' + latex(Rational(coefficient))
            else:
                string += ' - ' + latex(Rational(-coefficient))
            string += latex(symbol)
        constant = self._constant
        if not string:
            string += latex(Rational(constant))
        elif constant > 0:
            string += ' + ' + latex(Rational(constant))
        elif constant < 0:
            string += ' - ' + latex(Rational(-constant))
        return '${}$'.format(string)

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol and always is false.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build an expression from a linear sympy expression.

        Raises ValueError if a term is neither a constant nor a coefficient
        times a sympy.Symbol.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """
        Convert to a sympy expression, with every coefficient and the
        constant passed as a sympy.Rational.
        """
        import sympy
        # Adding linpy Rationals to the int 0 would stay in linpy (via
        # Expression.__radd__) for a constant expression, so convert the
        # numbers explicitly.
        expr = 0
        for symbol, coefficient in self._coefficients.items():
            factor = sympy.Rational(coefficient.numerator,
                                    coefficient.denominator)
            expr += factor * sympy.Symbol(symbol.name)
        expr += sympy.Rational(self._constant.numerator,
                               self._constant.denominator)
        return expr
323
324
class Symbol(Expression):
    """
    A named variable, i.e. the linear expression 1*name.

    Two Symbols are equal when their names are equal; a Dummy is never
    equal to a plain Symbol, even one with the same name.
    """

    def __new__(cls, name):
        """
        Create the symbol called name (surrounding whitespace is stripped).
        Raises TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        # object.__new__ rather than Expression.__new__, which, given no
        # coefficients, would return Rational(0) instead of a new instance.
        self = object.__new__(cls)
        # _name must be set before self is used as a dict key below, since
        # __hash__ reads it through sortkey().
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Key used to order symbols inside an Expression."""
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        return not isinstance(other, Dummy) and isinstance(other, Symbol) \
            and self.name == other.name

    def asdummy(self):
        """Return a new Dummy with the same name."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Convert an AST consisting of a single name into a Symbol; raise
        SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '${}$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Symbol into a Symbol of the same name.
        Raises TypeError for any other sympy object.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
381
382
class Dummy(Symbol):
    """
    A symbol that is distinct from every other symbol, including other
    Dummy instances with the same name.

    Each instance takes a unique index from a class-wide counter; equality
    compares that index and the hash combines name and index.
    """

    _count = 0  # index handed to the next Dummy created

    def __new__(cls, name=None):
        """
        Create a new dummy symbol; name defaults to 'Dummy_<index>'.
        Raises TypeError if a name is given and is not a string.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            # Same check and message as Symbol.__new__.
            raise TypeError('name must be a string')
        # Bypass Expression.__new__/Symbol.__new__, which would not create
        # an indexed dummy.
        self = object.__new__(cls)
        # _index and _name must be set before self is used as a dict key,
        # since __hash__ reads them through sortkey().
        self._index = Dummy._count
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Key used to order symbols inside an Expression."""
        return self._name, self._index

    def __eq__(self, other):
        return isinstance(other, Dummy) and self._index == other._index

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '${}_{{{}}}$'.format(self.name, self._index)
414
415
def symbols(names):
    """
    Create a tuple of Symbol instances.

    names is either one string of names separated by commas and/or
    whitespace, or an iterable of name strings.
    """
    if isinstance(names, str):
        # Commas count as separators, just like whitespace.
        names = names.replace(',', ' ').split()
    created = [Symbol(name) for name in names]
    return tuple(created)
420
421
class Rational(Expression, Fraction):
    """
    A constant linear expression: an exact rational number that is also an
    Expression without symbols.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Create numerator/denominator; the arguments are passed straight to
        fractions.Fraction.
        """
        self = Fraction.__new__(cls, numerator, denominator)
        # Expression-side state: no symbols, the value itself as constant.
        self._symbols = ()
        self._dimension = 0
        self._coefficients = {}
        self._constant = Fraction(self)
        return self

    # Hash and truthiness come from the numeric value.
    __hash__ = Fraction.__hash__

    @property
    def constant(self):
        return self

    def isconstant(self):
        return True

    __bool__ = Fraction.__bool__

    @classmethod
    def fromstring(cls, string):
        """
        Parse a number such as '3' or '-1/2'.
        Raises TypeError if string is not a str.
        """
        if isinstance(string, str):
            return Rational(string)
        raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.denominator != 1:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)
        return '{!r}'.format(self.numerator)

    def _repr_latex_(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator == 1:
            return '${}$'.format(numerator)
        if numerator < 0:
            # Put the sign in front of the fraction bar.
            return '$-\\frac{{{}}}{{{}}}$'.format(-numerator, denominator)
        return '$\\frac{{{}}}{{{}}}$'.format(numerator, denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Rational (or any numbers.Rational) into a Rational.
        Raises TypeError for anything else.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')