Better implementation of symbols and constants
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from collections import OrderedDict
7 from fractions import Fraction, gcd
8
9
# Public API of this module: the expression classes and the symbols() helper.
__all__ = ['Expression', 'Symbol', 'symbols', 'Constant']
15
16
def _polymorphic(func):
    """
    Decorate a binary operator so that its right operand is normalized.

    Expression operands are passed through unchanged and rational numbers
    are wrapped in a Constant before calling func. Any other operand makes
    the decorated operator return NotImplemented, so that Python can fall
    back to the reflected operation of the other operand.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational coefficients times named symbols,
    plus a rational constant. The constructor normalizes its result: an
    expression without any non-zero coefficient is returned as a Constant,
    and a single symbol with coefficient 1 and no constant is returned as a
    Symbol. Instances are hashable and expose no mutating methods.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
        '_hash',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients is None, a string (parsed with fromstring), a dict or
        an iterable of (symbol, coefficient) pairs. Symbols are names or
        Symbol instances; coefficients are rational numbers or Constant
        instances. constant is a rational number or a Constant.

        Raises TypeError on arguments of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        # Drop zero terms first so that the normalizations below see only
        # the symbols that actually occur in the expression.
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Constant instances')
            self._coefficients[symbol] = coefficient
        # Symbols are kept sorted by name; symbols, coefficients() and the
        # hash all depend on this canonical order.
        self._coefficients = OrderedDict(sorted(self._coefficients.items()))
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        self._hash = hash((tuple(self._coefficients.items()), self._constant))
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a name or a Symbol), 0 if absent.

        Raises TypeError if symbol is neither a string nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    def __getitem__(self, symbol):
        # Dispatch dynamically: a class-level alias (__getitem__ =
        # coefficient) would stay bound to this implementation, which reads
        # _coefficients, a slot that Symbol and Constant never set.
        return self.coefficient(symbol)

    def coefficients(self):
        """Yield the (symbol name, coefficient) pairs, sorted by name."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the symbol names with a non-zero coefficient, sorted."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols occurring in the expression."""
        return self._dimension

    def __hash__(self):
        return self._hash

    def isconstant(self):
        return False

    def issymbol(self):
        return False

    def values(self):
        """Yield every coefficient in symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def __bool__(self):
        # A genuine Expression always has a non-zero coefficient.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # _polymorphic turns numbers into Constants and yields
        # NotImplemented for unsupported operands; the original code
        # negated NotImplemented and raised a confusing TypeError.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands contain symbols, because the
        product would not be linear.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is a constant and other is not: let other.__rmul__ handle it
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant.

        Raises ValueError when the divisor contains symbols, and
        ZeroDivisionError when it is zero.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic
    def __rtruediv__(self, other):
        # Reached for e.g. `5 / expr`. The original tested
        # isinstance(other, self), which always raised TypeError. Delegating
        # to Constant.__truediv__ returns a Constant when self is constant
        # and raises ValueError ('non-linear expression') otherwise.
        return other / self

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # Compare through the public coefficients() API: Symbol and Constant
        # never set the _coefficients slot, so reading it directly raised
        # AttributeError for e.g. `(x + 1) == 3`.
        return isinstance(other, Expression) and \
            dict(self.coefficients()) == dict(other.coefficients()) and \
            self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def _toint(self):
        """
        Scale the expression so that all its values become integers.

        Returns self multiplied by the least common multiple of the
        denominators of its coefficients and constant.
        """
        # fractions.gcd was removed in Python 3.9; math.gcd is equivalent
        # for the positive integer denominators used here.
        from math import gcd
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Accepts a module or expression tree wrapping a single expression
        made of names, numeric literals, unary +/- and binary +, -, *, /.
        Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # tree produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, getattr(ast, 'Constant', ())) and \
                isinstance(node.value, numbers.Number) and \
                not isinstance(node.value, bool):
            # Python >= 3.8 parses numeric literals into ast.Constant
            return Constant(node.value)
        elif isinstance(node, getattr(ast, 'Num', ())):
            # older Pythons use ast.Num (deprecated in 3.8, removed in 3.14)
            return Constant(node.n)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse an expression such as '2x - 3y + 1/2'.

        The string is only parsed into an AST, never evaluated.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # The second positional parameter of ast.parse() is the filename, so
        # the former call ast.parse(string, 'eval') silently used 'exec'
        # mode. Keep 'exec' mode (the resulting ast.Module is unwrapped by
        # _fromast) but state it explicitly.
        tree = ast.parse(string, mode='exec')
        return cls._fromast(tree)

    def __str__(self):
        terms = []
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if not terms:
                # the leading term carries its own sign
                if coefficient == 1:
                    terms.append(symbol)
                elif coefficient == -1:
                    terms.append('-{}'.format(symbol))
                else:
                    terms.append('{}*{}'.format(coefficient, symbol))
            elif coefficient == 1:
                terms.append(' + {}'.format(symbol))
            elif coefficient == -1:
                terms.append(' - {}'.format(symbol))
            elif coefficient > 0:
                terms.append(' + {}*{}'.format(coefficient, symbol))
            else:
                terms.append(' - {}*{}'.format(-coefficient, symbol))
        constant = self.constant
        if constant != 0 and not terms:
            terms.append('{}'.format(constant))
        elif constant > 0:
            terms.append(' + {}'.format(constant))
        elif constant < 0:
            terms.append(' - {}'.format(-constant))
        return ''.join(terms) or '0'

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant/symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a linear sympy expression with rational coefficients.

        Raises ValueError if a term is not a rational multiple of a symbol
        or of one.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Convert the expression to an equivalent sympy expression."""
        import sympy

        def rational(value):
            # Build sympy numbers explicitly so that even a constant-only
            # expression yields a sympy object rather than a Fraction.
            return sympy.Rational(value.numerator, value.denominator)

        expr = rational(self.constant)
        for symbol, coefficient in self.coefficients():
            expr += rational(coefficient) * sympy.Symbol(symbol)
        return expr
338
339
class Symbol(Expression):
    """
    A single named variable, i.e. the linear expression 1*name + 0.
    """

    # '_hash' is already a slot of Expression. Re-declaring an inherited
    # slot shadows the base descriptor, which the data model documents as
    # undefined behaviour, and wastes space in every instance.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create a symbol from a name (surrounding whitespace is stripped)
        or from another Symbol.

        Raises TypeError if name is neither a string nor a Symbol.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        name = name.strip()
        self = object.__new__(cls)
        self._name = name
        self._hash = hash(self._name)
        return self

    @property
    def name(self):
        return self._name

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        """Return 1 if symbol names this symbol, 0 otherwise."""
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        if symbol == self.name:
            return 1
        else:
            return 0

    # Re-bind here: the alias inherited from Expression may point at
    # Expression.coefficient, which reads the _coefficients slot that a
    # Symbol never sets (AttributeError).
    __getitem__ = coefficient

    def coefficients(self):
        yield self.name, 1

    @property
    def constant(self):
        return 0

    @property
    def symbols(self):
        return self.name,

    @property
    def dimension(self):
        return 1

    def issymbol(self):
        return True

    def __eq__(self, other):
        return isinstance(other, Symbol) and self.name == other.name

    @classmethod
    def _fromast(cls, node):
        """
        Build a symbol from an AST that wraps a single name.

        Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # tree produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol; raise TypeError for anything else."""
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
415 raise TypeError('expr must be a sympy.Symbol instance')
416
417
def symbols(names):
    """
    Create several Symbol instances at once.

    names is either a string of names separated by whitespace and/or commas
    (e.g. 'x, y z') or an iterable of names or Symbol instances.

    Returns a tuple of Symbol objects. The tuple can be unpacked, indexed
    and iterated more than once. Because the tuple is built eagerly, invalid
    names raise TypeError immediately rather than on first iteration.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(Symbol(name) for name in names)
422
423
class Constant(Expression):
    """
    A rational constant, i.e. a linear expression without any symbol.

    The value is stored as a fractions.Fraction.
    """

    # '_constant' and '_hash' are inherited slots of Expression; declaring
    # them again would shadow the base descriptors (undefined behaviour per
    # the data model) and waste space in every instance.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a constant from another Constant, or from any arguments
        accepted by fractions.Fraction(numerator, denominator).
        """
        self = object.__new__(cls)
        if denominator is None and isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            self._constant = Fraction(numerator, denominator)
        self._hash = hash(self._constant)
        return self

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        """Return 0: no symbol occurs in a constant."""
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return 0

    # Re-bind here: the alias inherited from Expression may point at
    # Expression.coefficient, which reads the _coefficients slot that a
    # Constant never sets (AttributeError).
    __getitem__ = coefficient

    def coefficients(self):
        yield from []

    @property
    def symbols(self):
        return ()

    @property
    def dimension(self):
        return 0

    def isconstant(self):
        return True

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Constant) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """Parse a rational literal such as '3' or '-1/2'."""
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Rational (or a Python rational number)."""
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')

    def tosympy(self):
        """Convert the constant to a sympy.Rational (inverse of fromsympy)."""
        import sympy
        return sympy.Rational(self.constant.numerator,
            self.constant.denominator)
493 else:
494 raise TypeError('expr must be a sympy.Rational instance')