Cleaner implementation of Rational
[linpy.git] / pypol / linexprs.py
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from fractions import Fraction
from math import gcd
8
9
# Public names exported by ``from linexprs import *``.
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
15
16
def _polymorphic(func):
    """
    Decorate a binary operator so that it also accepts plain rational numbers.

    A right operand that is already an Expression is passed through unchanged,
    a numbers.Rational is wrapped in Rational first, and any other operand
    makes the operator return NotImplemented so Python can try the reflected
    method of the other operand.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational multiples of symbols plus a rational
    constant term. Construction normalizes the result: zero coefficients are
    dropped, an expression without any symbol is returned as a Rational, and
    an expression equal to ``1*symbol`` is returned as that Symbol.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients is either a string (parsed with fromstring()), a mapping
        from Symbol to coefficient, or an iterable of (Symbol, coefficient)
        pairs; constant is the constant term. Coefficients and the constant
        must be numbers.Rational instances and symbols must be Symbol
        instances, otherwise TypeError is raised.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        # The pairs are traversed twice below (validation, then filtering);
        # build a list first so that a one-shot iterator such as a generator
        # is not exhausted by the first pass.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be Rational instances')
        coefficients = [(symbol, Fraction(coefficient))
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a Rational instance')
        constant = Fraction(constant)
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        # Keep the terms in a canonical order: both __eq__ (OrderedDict
        # comparison) and __hash__ (tuple of items) are order-sensitive.
        self._coefficients = OrderedDict(sorted(
            coefficients, key=lambda item: item[0].sortkey()))
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """Return the coefficient of symbol as a Rational (0 if absent)."""
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        try:
            return Rational(self._coefficients[symbol])
        except KeyError:
            return Rational(0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs, coefficients as Rationals."""
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """The constant term, as a Rational."""
        return Rational(self._constant)

    @property
    def symbols(self):
        """Tuple of the symbols of the expression, in sortkey() order."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        return False

    def issymbol(self):
        return False

    def values(self):
        """Yield every coefficient, then the constant term, as Rationals."""
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant.

        The product of two non-constant expressions is not linear, so
        NotImplemented is returned in that case.
        """
        if isinstance(other, Rational):
            return other.__rmul__(self)
        return NotImplemented

    # Multiplication by a constant commutes, so the reflected form is the same.
    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """Divide by a constant; other divisors give NotImplemented."""
        if isinstance(other, Rational):
            return other.__rtruediv__(self)
        return NotImplemented

    def __rtruediv__(self, other):
        # other / self is not linear when self has symbols. Division is not
        # commutative, so this must not reuse __truediv__ (which would turn
        # 2 / x into x / 2). Rational overrides this for the constant case.
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    @_polymorphic
    def __le__(self, other):
        """Build the relation self <= other through polyhedra.Le."""
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        """Build the relation self < other through polyhedra.Lt."""
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        """Build the relation self >= other through polyhedra.Ge."""
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        """Build the relation self > other through polyhedra.Gt."""
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return self multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that all of them
        become integers.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions for symbols.

        Call as subs(symbol, expression), subs(mapping) or subs(iterable of
        (symbol, expression) pairs). Substitutions are applied one after the
        other, so a later one also acts on what earlier ones introduced.
        Raises TypeError if a substituted key is not a Symbol.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            coefficients = [(othersymbol, coefficient)
                            for othersymbol, coefficient
                            in result._coefficients.items()
                            if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = Expression(coefficients, constant) + \
                coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a parsed Python AST node.

        Supported: names, int/float literals, unary + and -, and binary
        + - * /. Anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Constant) and \
                isinstance(node.value, (int, float, complex)) and \
                not isinstance(node.value, bool):
            # Same literal types the deprecated ast.Num node matched.
            return Rational(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or a closing parenthesis directly followed by an identifier or
    # an opening parenthesis, e.g. '5x', '2(x + 1)', '(x)(y)'. The lookbehind
    # keeps digits inside identifiers (as in 'x2y') from being split off.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """Parse a string such as '2x + 3y - 1' into an expression."""
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # Surrounding whitespace would otherwise trip the parser's indentation
        # rules; mode is given by keyword (the second positional argument of
        # ast.parse is the filename, not the mode).
        tree = ast.parse(string.strip(), mode='exec')
        return cls._fromast(tree)

    def __repr__(self):
        # Work on the raw Fraction values: comparing Rational wrappers with
        # '>' or '<' would call the ordering operators above, which build
        # polyhedra instead of returning booleans.
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                string += '' if i == 0 else ' + '
                string += '{!r}'.format(symbol)
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
                string += '{!r}'.format(symbol)
            elif i == 0:
                string += '{}*{!r}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{!r}'.format(coefficient, symbol)
            else:
                string += ' - {}*{!r}'.format(-coefficient, symbol)
        constant = self._constant
        if not string:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or symbol
        (or always if always is true)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a linear sympy expression.

        Raises ValueError if expr has a term that is neither a constant nor a
        rational multiple of a single sympy.Symbol.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """Convert to a sympy expression."""
        import sympy
        # Build every number as a sympy.Rational so the result is a sympy
        # object even without symbols (Rational inherits this method):
        # adding a linpy Rational to the int 0 would give a linpy Rational.
        constant = self._constant
        expr = sympy.Rational(constant.numerator, constant.denominator)
        for symbol, coefficient in self._coefficients.items():
            factor = sympy.Rational(coefficient.numerator,
                                    coefficient.denominator)
            expr += factor * sympy.Symbol(symbol.name)
        return expr
295
296
class Symbol(Expression):
    """
    A named variable, behaving as the linear expression ``1*name``.

    Symbols compare equal by name; a Dummy never equals a plain Symbol.
    """

    def __new__(cls, name):
        """Create the symbol called name (surrounding whitespace stripped).

        Raises TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # _name must be set before the dict below is built: hashing self goes
        # through sortkey(), which reads it.
        self._name = name.strip()
        self._coefficients = {self: 1}
        self._constant = 0
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Key used to order symbols inside an expression."""
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        return not isinstance(other, Dummy) and isinstance(other, Symbol) \
            and self.name == other.name

    def asdummy(self):
        """Return a new Dummy with the same name."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """Build a Symbol from a parsed name; anything else is a SyntaxError."""
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol into an instance of this class.

        Raises TypeError if expr is not a sympy.Symbol.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
350
351
class Dummy(Symbol):
    """
    A symbol distinct from every other symbol, even one with the same name.

    Equality and hashing use a per-instance creation index.
    """

    # Number of Dummy instances created so far, i.e. the next instance's index.
    _count = 0

    def __new__(cls, name=None):
        """Create a fresh dummy; the name defaults to 'Dummy_<index>'.

        Raises TypeError if name is given and is not a string (consistent
        with Symbol).
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # _index and _name must be set before the dict below is built:
        # hashing self goes through sortkey(), which reads both. The counter
        # is read and bumped on Dummy itself rather than cls, so that a
        # subclass does not start its own shadow counter.
        self._index = Dummy._count
        self._name = name.strip()
        self._coefficients = {self: 1}
        self._constant = 0
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Order by name, then by creation index."""
        return self._name, self._index

    def __eq__(self, other):
        return isinstance(other, Dummy) and self._index == other._index

    def __repr__(self):
        return '_{}'.format(self.name)
380
381
def symbols(names):
    """
    Create a tuple of Symbols.

    names is either a string of names separated by commas and/or whitespace,
    or an iterable of name strings.
    """
    namelist = (names.replace(',', ' ').split()
                if isinstance(names, str) else names)
    return tuple(map(Symbol, namelist))
386
387
class Rational(Expression, Fraction):
    """
    A rational constant that is at once a Fraction and a linear expression
    without symbols.
    """

    def __new__(cls, numerator=0, denominator=None):
        """Build the value exactly as Fraction(numerator, denominator) would."""
        self = Fraction.__new__(cls, numerator, denominator)
        # Expression-side state: no symbols, and the value itself (as a plain
        # Fraction) as the constant term.
        self._symbols = ()
        self._dimension = 0
        self._coefficients = {}
        self._constant = Fraction(self)
        return self

    def __hash__(self):
        # Use Fraction's numeric hash rather than Expression's.
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """The constant term, which is this Rational itself."""
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    @_polymorphic
    def __mul__(self, other):
        """Scale every coefficient and the constant of other by self."""
        factor = self._constant
        scaled = {symbol: coefficient * factor
                  for symbol, coefficient in other._coefficients.items()}
        return Expression(scaled, other._constant * factor)

    __rmul__ = __mul__

    @_polymorphic
    def __rtruediv__(self, other):
        """Compute other / self by dividing every term of other by self."""
        divisor = self._constant
        quotients = {symbol: coefficient / divisor
                     for symbol, coefficient in other._coefficients.items()}
        return Expression(quotients, other._constant / divisor)

    @classmethod
    def fromstring(cls, string):
        """Parse a string accepted by Fraction, e.g. '3/4' or '1.5'."""
        if isinstance(string, str):
            return Rational(Fraction(string))
        raise TypeError('string must be a string instance')

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Rational (or any numbers.Rational)."""
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')