Simplify TypeError messages in Expression.__new__
[linpy.git] / pypol / linexprs.py
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from fractions import Fraction
from math import gcd
8
9
# Public API of this module: the expression classes and the symbols() helper.
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
15
16
17 def _polymorphic(func):
18 @functools.wraps(func)
19 def wrapper(left, right):
20 if isinstance(right, Expression):
21 return func(left, right)
22 elif isinstance(right, numbers.Rational):
23 right = Rational(right)
24 return func(left, right)
25 return NotImplemented
26 return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a set of (symbol, rational coefficient) pairs plus a
    rational constant term. Construction normalizes the result: when no
    non-zero coefficient remains a Rational is returned, and a single symbol
    with coefficient 1 and a zero constant is returned as that Symbol.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a string (parsed with fromstring), a mapping
            {symbol: coefficient}, or an iterable of (symbol, coefficient)
            pairs; zero coefficients are dropped.
        constant -- the constant term.
        Coefficients and constant must be rational numbers or Rational
        instances.

        Raises TypeError on a string combined with a non-zero constant, on a
        non-Symbol key, or on a non-rational coefficient/constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        # Materialize first: the pairs are traversed more than once below,
        # which would silently drop every pair of a one-shot iterator.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        coefficients = [(symbol, coefficient)
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        self._coefficients = OrderedDict()
        # Sorting by sortkey() gives a canonical order, which __hash__ and
        # __repr__ rely on.
        for symbol, coefficient in sorted(coefficients,
                                          key=lambda item: item[0].sortkey()):
            if isinstance(coefficient, Rational):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be Rational instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Rational):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a Rational instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not occur.

        Raises TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield the (symbol, coefficient) pairs in symbol sort order."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the symbols with a non-zero coefficient, in sort order."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """Return False; Rational overrides this to return True."""
        return False

    def issymbol(self):
        """Return False; Symbol overrides this to return True."""
        return False

    def values(self):
        """Yield every coefficient, then the constant term."""
        yield from self._coefficients.values()
        yield self.constant

    def __bool__(self):
        # Only Rational (which overrides this) can be zero; an expression
        # with a non-zero coefficient is always truthy.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = defaultdict(Rational, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] += coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = defaultdict(Rational, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] -= coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # Computes other - self; _polymorphic has turned a number on the left
        # into a Rational, and rejects anything else with NotImplemented.
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands are non-constant (the product
        would not be linear).
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant.

        Raises ValueError if other is not constant; dividing by zero raises
        ZeroDivisionError (from Fraction).
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = Rational(coefficients[symbol], other.constant)
            constant = Rational(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic
    def __rtruediv__(self, other):
        """
        Compute other / self, where other was a number on the left.

        Only defined when self is constant; otherwise raises ValueError.
        """
        if self.isconstant():
            return Rational(other.constant, self.constant)
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other._parenstr(), self._parenstr()))

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # Compare through the public accessors: Symbol and Rational instances
        # never set the _coefficients slot.
        return isinstance(other, Expression) and \
            dict(self.coefficients()) == dict(other.coefficients()) and \
            self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        """Return Le(self, other) from the sibling polyhedra module."""
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        """Return Lt(self, other) from the sibling polyhedra module."""
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        """Return Ge(self, other) from the sibling polyhedra module."""
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        """Return Gt(self, other) from the sibling polyhedra module."""
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that all of them
        become integers.
        """
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions for symbols.

        Call as subs(symbol, expression), or as subs(substitutions) where
        substitutions is a mapping or an iterable of (symbol, expression)
        pairs. Substitutions are applied one after another, in order.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result.coefficients()
                if othersymbol != symbol]
            coefficient = result.coefficient(symbol)
            constant = result.constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Accepts names, numeric literals, unary + and -, and the binary
        operators + - * /; anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Constant) and \
                isinstance(node.value, (int, float, complex)) and \
                not isinstance(node.value, bool):
            # Same literals the deprecated ast.Num matched (removed in 3.14).
            return Rational(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or ')' immediately followed by an identifier or '(', e.g. '5x'
    # or '(x)(y)'. The lookbehind keeps digits that are part of an
    # identifier, such as in 'x2y', from being split off.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse an expression such as '2x - 3*y + 1'.

        Raises TypeError if string is not a str, SyntaxError on unsupported
        syntax.
        """
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # 'exec' mode yields an ast.Module, which _fromast unwraps.
        tree = ast.parse(string, mode='exec')
        return cls._fromast(tree)

    def __repr__(self):
        """Render the expression, e.g. 'x - 2*y + 3'."""
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                string += '' if i == 0 else ' + '
                string += '{!r}'.format(symbol)
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
                string += '{!r}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{!r}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{!r}'.format(coefficient, symbol)
                else:
                    string += ' - {}*{!r}'.format(-coefficient, symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or symbol
        (or always if always=True)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build an expression from a sympy expression.

        Raises ValueError on a term that is neither a constant nor a
        sympy.Symbol times a coefficient. Assumes coefficients are sympy
        rationals (they must expose .p and .q).
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """Return the equivalent sympy expression; each symbol becomes a
        sympy.Symbol with the same name."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr
323 return expr
324
325
class Symbol(Expression):
    """
    A named variable: the expression with coefficient 1 on itself and a zero
    constant.

    Symbols compare equal by name, except that a Symbol never equals a
    Dummy.
    """

    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """Create a symbol; surrounding whitespace is stripped from name.
        Raises TypeError if name is not a string."""
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        instance = object.__new__(cls)
        instance._name = name.strip()
        return instance

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def coefficient(self, symbol):
        """Return 1 if symbol equals this symbol, else 0.
        Raises TypeError if symbol is not a Symbol."""
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return 1 if symbol == self else 0

    def coefficients(self):
        """Yield the single pair (self, 1)."""
        yield (self, 1)

    @property
    def constant(self):
        """Always 0."""
        return 0

    @property
    def symbols(self):
        """The one-element tuple (self,)."""
        return (self,)

    @property
    def dimension(self):
        """Always 1."""
        return 1

    def sortkey(self):
        """Key used to order symbols inside an Expression: (name,)."""
        return (self.name,)

    def issymbol(self):
        return True

    def values(self):
        """Yield the single coefficient 1."""
        yield 1

    def __eq__(self, other):
        if isinstance(other, Dummy) or not isinstance(other, Symbol):
            return False
        return self.name == other.name

    def asdummy(self):
        """Return a new Dummy carrying the same name."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """Accept only a bare name; anything else raises SyntaxError."""
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        if isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        if isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol; raises TypeError for anything else."""
        import sympy
        if not isinstance(expr, sympy.Symbol):
            raise TypeError('expr must be a sympy.Symbol instance')
        return cls(expr.name)
403 else:
404 raise TypeError('expr must be a sympy.Symbol instance')
405
406
class Dummy(Symbol):
    """
    A symbol that is equal only to itself.

    Every construction takes the next value of a counter shared by all Dummy
    instances as its index, and equality compares indices. Two dummies with
    the same name are therefore still distinct. Without a name,
    'Dummy_<index>' is used.
    """

    # '_name' already comes from Symbol.__slots__; redeclaring it here would
    # create a second slot that shadows the base one.
    __slots__ = (
        '_index',
    )

    # Next index to hand out, shared by all Dummy instances.
    _count = 0

    def __new__(cls, name=None):
        """Create a fresh dummy. Raises TypeError if name is given but is
        not a string, consistent with Symbol."""
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._name = name.strip()
        self._index = Dummy._count
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Key used to order symbols: (name, index)."""
        return self._name, self._index

    def __eq__(self, other):
        return isinstance(other, Dummy) and self._index == other._index

    def __repr__(self):
        return '_{}'.format(self.name)
434 def __repr__(self):
435 return '_{}'.format(self.name)
436
437
def symbols(names):
    """
    Return a tuple of Symbol objects.

    names is either a string of names separated by commas and/or whitespace
    (e.g. 'x, y z'), or an iterable of individual names.
    """
    if not isinstance(names, str):
        return tuple(map(Symbol, names))
    return tuple(map(Symbol, names.replace(',', ' ').split()))
440 names = names.replace(',', ' ').split()
441 return tuple(Symbol(name) for name in names)
442
443
class Rational(Expression):
    """
    A constant expression holding a fractions.Fraction.
    """

    # '_constant' already comes from Expression.__slots__; redeclaring it
    # here would create a second slot that shadows the base one.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create the rational numerator / denominator.

        Accepts whatever fractions.Fraction accepts (ints, Fractions, or a
        string such as '3/4' when denominator is omitted). Either argument
        may also be a Rational instance, which is unwrapped first.
        """
        if isinstance(numerator, Rational):
            numerator = numerator.constant
        if isinstance(denominator, Rational):
            denominator = denominator.constant
        self = object.__new__(cls)
        self._constant = Fraction(numerator, denominator)
        return self

    def __hash__(self):
        # Same hash as the wrapped Fraction, so Rational(n) == n hash alike.
        return hash(self.constant)

    def coefficient(self, symbol):
        """Return 0; raises TypeError if symbol is not a Symbol."""
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return 0

    def coefficients(self):
        """Yield nothing: a constant has no symbols."""
        yield from ()

    @property
    def symbols(self):
        return ()

    @property
    def dimension(self):
        return 0

    def isconstant(self):
        return True

    def values(self):
        """Yield the constant only."""
        yield self._constant

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Rational) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """Parse a number such as '3/4'; raises TypeError for a non-str."""
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        return Rational(Fraction(string))

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Rational (or a plain rational number); raises
        TypeError otherwise."""
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return Rational(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')
503 else:
504 raise TypeError('expr must be a sympy.Rational instance')