Make symbolnames return a tuple
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from collections import OrderedDict
7 from fractions import Fraction, gcd
8
9
# Public API of this module: the expression classes and the helpers used
# to create symbols and normalize symbol names.
__all__ = [
    'Expression',
    'Symbol',
    'symbols',
    'symbolname',
    'symbolnames',
    'Constant',
]
15
16
def _polymorphic(func):
    """
    Decorate a binary operator so that its right operand is an Expression.

    Expression operands are passed through unchanged, rational numbers are
    wrapped in a Constant first, and any other operand makes the wrapper
    return NotImplemented so that Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of symbols weighted by rational coefficients
    plus a rational constant term, e.g. 2*x - y + 1/2.  The constructor
    normalizes its result: an expression without symbols is returned as a
    Constant, and a lone symbol with coefficient 1 and no constant term is
    returned as a Symbol.  The operators <, <=, > and >= do not return
    booleans; they build objects from the polyhedra module.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
        '_hash',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create an expression from its coefficients and constant term.

        coefficients is either a string (parsed with fromstring(); constant
        must then be left falsy), a dict or iterable of
        (symbol, coefficient) pairs, or None.  Symbols are names or Symbol
        instances; coefficients and constant are rational numbers or
        Constant instances.  Zero coefficients are dropped.

        Raises TypeError on non-rational values, or when a string is given
        together with a truthy constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # Bare allocation; every slot is filled in below.
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            symbol = symbolname(symbol)
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Constant instances')
            self._coefficients[symbol] = coefficient
        # Sort by name so that equal expressions list their symbols in the
        # same order, which makes symbols, hash and str() deterministic.
        self._coefficients = OrderedDict(sorted(self._coefficients.items()))
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        self._hash = hash((tuple(self._coefficients.items()), self._constant))
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a name or a Symbol), or 0 when
        the symbol does not occur in the expression.
        """
        symbol = symbolname(symbol)
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Iterate over the (name, coefficient) pairs, sorted by name."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the names of the symbols, sorted by name."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols occurring in the expression."""
        return self._dimension

    def __hash__(self):
        return self._hash

    def isconstant(self):
        return False

    def issymbol(self):
        return False

    def values(self):
        """Yield the coefficients in symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def __bool__(self):
        # Expressions are always truthy; Constant overrides this to
        # compare its value against zero.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # other - self.  _polymorphic has turned a rational other into a
        # Constant and returns NotImplemented for unsupported types, so
        # Python reports a regular "unsupported operand" TypeError.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant but other is not: let Python try other's
        # reflected multiplication, which sees a constant operand.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                    Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic
    def __rtruediv__(self, other):
        """
        Compute other / self, e.g. for 1 / Constant(2).

        _polymorphic has already wrapped a rational other in a Constant,
        so this delegates to regular division: the quotient is defined
        only when self is constant, and raises ValueError otherwise.
        """
        return other / self

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # Symbol and Constant do not fill the _coefficients slot, so both
        # sides are compared through the public coefficients() iterator.
        return isinstance(other, Expression) and \
            dict(self.coefficients()) == dict(other.coefficients()) and \
            self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        # Builds polyhedra.Le(self, other); imported lazily here.
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        # Builds polyhedra.Lt(self, other); imported lazily here.
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        # Builds polyhedra.Ge(self, other); imported lazily here.
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        # Builds polyhedra.Gt(self, other); imported lazily here.
        from .polyhedra import Gt
        return Gt(self, other)

    def _toint(self):
        """
        Return the expression multiplied by the least common multiple of
        the denominators of its coefficients and constant, so that every
        value becomes an integer.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python syntax tree.

        Accepts a single-statement ast.Module, names, numbers, unary minus
        and the binary operators +, -, * and /.  Raises SyntaxError for any
        other node.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Num):
            return Constant(node.n)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    def subs(self, symbol, expression=None):
        """
        Substitute expressions for symbols.

        Called either as subs(symbol, expression) or with a single dict or
        iterable of (symbol, expression) pairs.  Substitutions are applied
        one after the other, so later ones act on the result of earlier
        ones.
        """
        if expression is None:
            if isinstance(symbol, dict):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            symbol = symbolname(symbol)
            result = result._subs(symbol, expression)
        return result

    def _subs(self, symbol, expression):
        """Replace the symbol named symbol by expression."""
        coefficients = {name: coefficient
            for name, coefficient in self.coefficients()
            if name != symbol}
        coefficient = self.coefficient(symbol)
        result = Expression(coefficients, self.constant)
        result += coefficient * expression
        return result

    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """Parse a string such as '2x - 3y + 1' into an expression."""
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # NOTE: 'eval' is passed as ast.parse()'s *filename* argument; the
        # string is parsed in the default 'exec' mode and _fromast()
        # unwraps the resulting single-statement ast.Module.
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __str__(self):
        """
        Format the expression as written by hand, e.g. '2*x - y + 1/2'.

        Coefficients of 1 and -1 are omitted and an expression with
        neither symbols nor a non-zero constant is written '0'.
        """
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is
        a constant or a symbol (or always parenthesized if always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a linear sympy expression with rational coefficients.

        Raises ValueError when a term is neither a constant nor a single
        sympy.Symbol.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Return the equivalent sympy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr
355
356
class Symbol(Expression):
    """
    A single symbol, i.e. the linear expression 1*name with no constant
    term.  Symbols compare equal and hash by name.
    """

    # '_hash' is already declared in Expression.__slots__; redeclaring an
    # inherited slot would shadow the base-class one, so only '_name' is
    # added here.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """Create a symbol from a name string or an existing Symbol."""
        name = symbolname(name)
        # Bare allocation: Expression.__new__ is deliberately bypassed.
        self = object.__new__(cls)
        self._name = name
        self._hash = hash(self._name)
        return self

    @property
    def name(self):
        """The symbol name, as normalized by symbolname()."""
        return self._name

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        """Return 1 if symbol refers to this symbol, else 0."""
        symbol = symbolname(symbol)
        if symbol == self.name:
            return 1
        else:
            return 0

    def coefficients(self):
        """Yield the single (name, 1) pair."""
        yield self.name, 1

    @property
    def constant(self):
        return 0

    @property
    def symbols(self):
        return (self.name,)

    @property
    def dimension(self):
        return 1

    def issymbol(self):
        return True

    def __eq__(self, other):
        return isinstance(other, Symbol) and self.name == other.name

    @classmethod
    def _fromast(cls, node):
        """
        Build a Symbol from a syntax tree that consists of a single name;
        raise SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol; raise TypeError for anything else."""
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
426
427
def symbols(names):
    """
    Create a tuple of Symbol instances.

    names is either a string of names separated by commas and/or
    whitespace (e.g. 'x, y z'), or an iterable of names or Symbol
    instances.  A tuple, unlike a one-shot generator, can be indexed,
    measured with len() and iterated more than once; this matches the
    tuple returned by symbolnames().
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(Symbol(name) for name in names)
432
def symbolname(symbol):
    """
    Return the normalized name of symbol.

    Strings are returned with surrounding whitespace stripped and Symbol
    instances yield their name.  Any other type raises TypeError.
    """
    if isinstance(symbol, str):
        return symbol.strip()
    if isinstance(symbol, Symbol):
        return symbol.name
    raise TypeError('symbol must be a string or a Symbol instance')
440
def symbolnames(symbols):
    """
    Return a tuple of symbol names.

    symbols is either a string of names separated by commas and/or
    whitespace (e.g. 'x, y z'), or an iterable of names or Symbol
    instances, each normalized by symbolname().  Both forms return a
    tuple.
    """
    if isinstance(symbols, str):
        return tuple(symbols.replace(',', ' ').split())
    return tuple(symbolname(symbol) for symbol in symbols)
445
446
class Constant(Expression):
    """
    A rational constant, i.e. a linear expression without symbols.
    """

    # '_constant' and '_hash' are already declared in Expression.__slots__;
    # redeclaring inherited slots would shadow them, so none are added.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a constant.

        Constant(c) copies the value of another Constant; otherwise the
        value is Fraction(numerator, denominator), so numerator may be
        anything Fraction accepts (e.g. an int, a Fraction or '1/3').
        """
        # Bare allocation: Expression.__new__ is deliberately bypassed.
        self = object.__new__(cls)
        if denominator is None and isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            self._constant = Fraction(numerator, denominator)
        self._hash = hash(self._constant)
        return self

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        """Return 0; symbol is still validated by symbolname()."""
        symbolname(symbol)
        return 0

    def coefficients(self):
        # A constant has no (name, coefficient) pairs.
        yield from ()

    @property
    def symbols(self):
        return ()

    @property
    def dimension(self):
        return 0

    def isconstant(self):
        return True

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Constant) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """Parse a rational number such as '3' or '-2/5'."""
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Rational (or a plain rational number)."""
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')
513 else:
514 raise TypeError('expr must be a sympy.Rational instance')