Use OrderedDict to store Expression coefficients
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from collections import OrderedDict
7 from fractions import Fraction, gcd
8
9
# Names exported by ``from <this module> import *``.
__all__ = ['Expression', 'Symbol', 'symbols', 'Constant']
15
16
def _polymorphic(func):
    """
    Decorate a binary Expression method so that it also accepts rationals.

    A right operand that is already an Expression is passed through as is,
    a numbers.Rational is wrapped in a Constant first, and any other type
    makes the wrapper return NotImplemented so that Python can try the
    reflected operation on the other operand.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational coefficients times named symbols,
    plus a rational constant. Coefficients are stored in an OrderedDict
    sorted by symbol name, so symbols, coefficients() and values() iterate
    in a deterministic order. The constructor returns a Constant or a Symbol
    instead when the result degenerates to one of them.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Build a linear expression.

        coefficients is either a string (parsed with fromstring; constant
        must then be left unset), a dict, or an iterable of
        (symbol, coefficient) pairs. Symbols are strings or Symbol instances,
        coefficients are rational numbers or Constant instances; zero
        coefficients are dropped. constant is a rational number or a
        Constant instance.

        Returns a Constant when no non-zero coefficient remains, and a Symbol
        for a single unit coefficient with a zero constant.

        Raises TypeError if an argument has an unsupported type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # Allocate through object.__new__ directly (no throwaway object()).
        self = object.__new__(cls)
        normalized = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Constant instances')
            normalized[symbol] = coefficient
        # OrderedDict equality is order-sensitive, so sort by symbol name to
        # make equal expressions store their coefficients in the same order.
        self._coefficients = OrderedDict(sorted(normalized.items()))
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a string or a Symbol), or 0 when
        the symbol does not occur in the expression.

        Raises TypeError for any other symbol type.
        """
        if isinstance(symbol, Symbol):
            # Use the name directly: str(symbol) would render the symbol as
            # a whole expression just to recover it.
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol name, coefficient) pairs in storage order."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the names of the symbols occurring in the expression."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols occurring in the expression."""
        return self._dimension

    def isconstant(self):
        """Return True only for expressions without symbols (Constant)."""
        return False

    def issymbol(self):
        """Return True only for single-symbol expressions (Symbol)."""
        return False

    def values(self):
        """Yield the coefficient of each symbol in order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def __bool__(self):
        # __new__ drops zero coefficients and returns a Constant when none
        # remain, so a plain Expression always has a non-zero term.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # Computes other - self once other has been coerced to an
        # Expression; unsupported operand types yield NotImplemented.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                for symbol, coefficient in self.coefficients()}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant but other is not: Python then calls
        # other.__rmul__(self), which takes the constant branch above.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        if other.isconstant():
            coefficients = {symbol: Fraction(coefficient, other.constant)
                for symbol, coefficient in self.coefficients()}
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic
    def __rtruediv__(self, other):
        # Reached for number / expression, with the number already coerced
        # to a Constant. other.__truediv__ divides when self is constant and
        # raises ValueError('non-linear expression: ...') otherwise.
        return other / self

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        """Return the constraint self <= other, built by polyhedra.Le."""
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        """Return the constraint self < other, built by polyhedra.Lt."""
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        """Return the constraint self >= other, built by polyhedra.Ge."""
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        """Return the constraint self > other, built by polyhedra.Gt."""
        from .polyhedra import Gt
        return Gt(self, other)

    def __hash__(self):
        """Hash the (symbol, coefficient) pairs together with the constant."""
        return hash((tuple(self.coefficients()), self._constant))

    def _toint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that all of them
        become integers.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @classmethod
    def _fromast(cls, node):
        """
        Recursively convert a node produced by ast.parse into an expression.

        Names, numeric literals, unary minus and the binary operators
        +, -, * and / are supported; any other node raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Num):
            return Constant(node.n)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # Implicit multiplication: a number or a closing parenthesis directly
    # followed by an identifier or an opening parenthesis. The lookbehind
    # keeps digits that are part of an identifier (the '1' in 'x1y') from
    # being mistaken for a numeric factor.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse string (e.g. '2x - y/3 + 1') into an expression.

        Juxtaposition is read as multiplication ('5x' means '5*x').
        Raises SyntaxError for unsupported syntax and ValueError for
        non-linear input.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # ast.parse's second positional parameter is the filename, not the
        # mode; the default 'exec' mode yields a Module that _fromast unwraps.
        tree = ast.parse(string)
        return cls._fromast(tree)

    def __str__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            first = (i == 0)
            if coefficient == 1:
                string += symbol if first else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += '-{}'.format(symbol) if first \
                    else ' - {}'.format(symbol)
            elif first:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0 and not self._coefficients:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol and always is false.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Build an expression from a linear sympy expression.

        Raises ValueError when expr contains a term that is neither a plain
        symbol nor the constant term.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Return the equivalent sympy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr
336
337
class Symbol(Expression):
    """
    A linear expression made of a single named symbol with coefficient 1
    and a zero constant.
    """

    # Declare only the attribute Symbol adds: the Expression slots are
    # inherited, and redeclaring them in a subclass creates duplicate slots
    # that shadow the base-class ones.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create the symbol called name.

        name is a string (surrounding whitespace is stripped) or an existing
        Symbol whose name is reused. Raises TypeError for any other type.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        name = name.strip()
        self = object.__new__(cls)
        # OrderedDict, like the coefficient mapping of other expressions.
        self._coefficients = OrderedDict([(name, 1)])
        self._constant = 0
        # A 1-tuple holding the name; tuple(name) would split it into
        # characters ('xy' -> ('x', 'y')).
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The name of the symbol."""
        return self._name

    def issymbol(self):
        return True

    @classmethod
    def _fromast(cls, node):
        """Build a Symbol from a parsed bare name; raise SyntaxError otherwise."""
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol; raise TypeError for anything else."""
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
385
386
def symbols(names):
    """
    Create Symbol instances from names.

    names is either a string of names separated by whitespace and/or commas
    (e.g. 'x, y z'), or an iterable of names (or Symbol instances). Returns a
    generator of Symbols, so ``x, y = symbols('x y')`` unpacks directly.
    """
    name_list = (names.replace(',', ' ').split()
        if isinstance(names, str) else names)
    return (Symbol(each) for each in name_list)
391
392
class Constant(Expression):
    """
    A linear expression without symbols: a single rational constant, stored
    as a fractions.Fraction.
    """

    # Like Expression and Symbol, keep instances free of a per-instance
    # __dict__; all attributes live in the inherited Expression slots.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a constant.

        Constant(c) with a Constant c copies its value; otherwise the
        arguments are passed to fractions.Fraction(numerator, denominator).
        """
        self = object.__new__(cls)
        if denominator is None and isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            self._constant = Fraction(numerator, denominator)
        # OrderedDict, like the coefficient mapping of other expressions.
        self._coefficients = OrderedDict()
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return self.constant != 0

    def __hash__(self):
        # Expression.__eq__ coerces plain rationals, so Constant(n) == n
        # holds; hashing the bare value keeps equal objects hashing equal
        # (e.g. a Constant can be used to look up an int dict key).
        return hash(self.constant)

    @classmethod
    def fromstring(cls, string):
        """Parse a number such as '3' or '-3/4'; raise TypeError for non-strings."""
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Rational (or a plain numbers.Rational); raise
        TypeError for anything else.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')
434 else:
435 raise TypeError('expr must be a sympy.Rational instance')