Improve tests involving iterators
[linpy.git] / pypol / linexprs.py
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from fractions import Fraction

try:
    from math import gcd
except ImportError:  # Python < 3.5; fractions.gcd was removed in Python 3.9
    from fractions import gcd
8
9
# Public names exported by "from <module> import *".
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
15
16
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is always an Expression.

    Expression operands are passed through unchanged, rational numbers are
    wrapped in a Rational, and any other operand makes the wrapper return
    NotImplemented so that Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of terms (a Symbol times a rational coefficient)
    plus a rational constant. The constructor normalizes its result: with no
    nonzero term it returns a Rational, and with a single term of coefficient
    1 and a zero constant it returns that Symbol.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Build a linear expression.

        coefficients is either a string (parsed with fromstring; constant
        must then be falsy), None (the result is Rational(constant)), a dict
        mapping Symbol to coefficient, or any iterable of (Symbol,
        coefficient) pairs -- one-shot iterators and generators included.
        Coefficients and constant must be rational numbers or Rational
        instances; zero coefficients are dropped.

        Raises TypeError on a non-Symbol key, a non-rational coefficient or
        constant, or a string combined with a nonzero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        # Materialize the pairs: they are traversed twice below, and a
        # one-shot iterator (e.g. another expression's coefficients()) would
        # be exhausted by the first pass, silently dropping every term.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        self._coefficients = OrderedDict()
        # Terms are stored sorted by symbol so that equal expressions have
        # the same item order, which __hash__ and __repr__ depend on.
        for symbol, coefficient in sorted(coefficients,
                key=lambda item: item[0].sortkey()):
            if isinstance(coefficient, Rational):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Rational instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Rational):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Rational instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not appear.

        Raises TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """Iterate over the (symbol, coefficient) pairs in sorted symbol order."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the symbols with a nonzero coefficient, in sorted order."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols with a nonzero coefficient."""
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """Return False; an Expression instance always has a symbol."""
        return False

    def issymbol(self):
        """Return False; an Expression instance is never a bare symbol."""
        return False

    def values(self):
        """Iterate over the coefficients, then the constant term."""
        yield from self._coefficients.values()
        yield self.constant

    def __bool__(self):
        # An expression with at least one symbol is never the zero constant.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = defaultdict(Rational, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] += coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = defaultdict(Rational, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] -= coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # Decorated so that an unsupported operand yields NotImplemented
        # (and a proper "unsupported operand" TypeError) instead of
        # attempting to negate NotImplemented.
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands contain symbols, because the
        product would not be linear.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant: let Python retry with other.__rmul__(self).
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant, keeping coefficients exact (Fraction based).

        Raises ValueError when the divisor contains symbols.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = Rational(coefficients[symbol], other.constant)
            constant = Rational(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic
    def __rtruediv__(self, other):
        """
        Return other / self, e.g. for ``1 / expr``.

        Only defined when self is constant; raises ValueError otherwise.
        """
        # _polymorphic has turned a number into a Rational, so ordinary
        # division covers both the constant case and the non-linear error
        # message.
        return other / self

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # Compare through the public API: Symbol and Rational never set the
        # _coefficients slot, so reading other._coefficients would raise
        # AttributeError for e.g. (x + 1) == 1.
        return isinstance(other, Expression) and \
            self.constant == other.constant and \
            self._coefficients == dict(other.coefficients())

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant.

        All coefficients and the constant of the result are integers.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions for symbols.

        Call as subs(symbol, expression), or as subs(pairs) where pairs is a
        dict or an iterable of (symbol, expression) pairs. Substitutions are
        applied one after another, in iteration order.
        """
        if expression is None:
            if isinstance(symbol, dict):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result.coefficients()
                if othersymbol != symbol]
            coefficient = result.coefficient(symbol)
            constant = result.constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @staticmethod
    def _astnumber(node):
        """
        Return the value of a numeric-literal AST node, or None otherwise.

        Python >= 3.8 parses numeric literals as ast.Constant, while older
        versions produce ast.Num (deprecated in 3.8, removed in 3.14); both
        are accepted. Booleans are not treated as numbers, like ast.Num.
        """
        if isinstance(node, getattr(ast, 'Constant', ())):
            value = node.value
        elif type(node).__name__ == 'Num':
            value = node.n
        else:
            return None
        if isinstance(value, bool) or not isinstance(value, numbers.Number):
            return None
        return value

    @classmethod
    def _fromast(cls, node):
        """
        Convert a parsed Python AST into an expression.

        Accepts names, numeric literals, unary minus and the binary operators
        + - * /; anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        number = Expression._astnumber(node)
        if number is not None:
            return Rational(number)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """Parse a string such as '2x + 3*y - 1' into an expression."""
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # The second positional parameter of ast.parse is the filename, not
        # the mode. Parse explicitly in 'exec' mode, which yields the
        # single-statement Module that _fromast unwraps.
        tree = ast.parse(string, mode='exec')
        return cls._fromast(tree)

    def __repr__(self):
        """Return a readable form such as '2*x - y + 1'."""
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol.name
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or a
        symbol (or always is true, which forces the parentheses)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression (sympy is imported lazily).

        Raises ValueError for a term that is neither a constant nor a
        coefficient times a sympy.Symbol. Coefficients are assumed to be
        sympy rationals (they are read through .p and .q).
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """Return the equivalent sympy expression (sympy is imported lazily)."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr
338
339
class Symbol(Expression):
    """
    A named variable, i.e. the linear expression 1*name.

    Symbols compare equal when their names are equal; a Symbol never
    compares equal to a Dummy.
    """

    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create the symbol called name, with surrounding whitespace stripped.

        Raises TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        # Allocate with object.__new__ directly: Expression.__new__ would
        # interpret the arguments as coefficients.
        self = object.__new__(cls)
        self._name = name.strip()
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def coefficient(self, symbol):
        """
        Return 1 if symbol equals this symbol, else 0.

        Raises TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        if symbol == self:
            return 1
        else:
            return 0

    def coefficients(self):
        """Yield the single (self, 1) pair."""
        yield self, 1

    @property
    def constant(self):
        """Always 0 for a symbol."""
        return 0

    @property
    def symbols(self):
        """The one-element tuple (self,)."""
        return self,

    @property
    def dimension(self):
        """Always 1 for a symbol."""
        return 1

    def sortkey(self):
        """Key that orders terms inside an expression: a 1-tuple of the name."""
        return self.name,

    def issymbol(self):
        return True

    def values(self):
        """Yield the coefficient 1 (the zero constant is not yielded)."""
        yield 1

    def __eq__(self, other):
        return not isinstance(other, Dummy) and isinstance(other, Symbol) \
            and self.name == other.name

    def asdummy(self):
        """Return a new Dummy with the same name."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """Accept only a (possibly wrapped) bare name; raise SyntaxError otherwise."""
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Symbol (sympy is imported lazily).

        Raises TypeError for any other sympy object.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return Symbol(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
416
417
class Dummy(Symbol):
    """
    A symbol that is distinct from every other symbol, even one with the same
    name. Two Dummy instances are equal only if they share the same creation
    index, i.e. only if they are the same dummy.
    """

    # '_name' is already a slot of Symbol; redeclaring it here would only add
    # a second, shadowing slot to every instance.
    __slots__ = (
        '_index',
    )

    # Number of Dummy instances created so far; the next dummy's index.
    _count = 0

    def __new__(cls, name=None):
        """
        Create a new dummy symbol.

        name defaults to 'Dummy_<index>'. Raises TypeError if name is neither
        None nor a string, consistent with Symbol.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._name = name.strip()
        self._index = Dummy._count
        Dummy._count += 1
        return self

    # Must be redefined: defining __eq__ in a class resets __hash__ to None.
    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Order by name, then by creation index."""
        return self._name, self._index

    def __eq__(self, other):
        return isinstance(other, Dummy) and self._index == other._index
444
445
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name.

    names is either a string whose names are separated by commas and/or
    whitespace (e.g. 'x, y z'), or an iterable of name strings.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
450
451
class Rational(Expression):
    """
    A constant linear expression; its value is stored as a fractions.Fraction.
    """

    # '_constant' is already a slot of Expression; redeclaring it here would
    # only add a second, shadowing slot to every instance.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create the constant numerator/denominator.

        Both arguments accept anything fractions.Fraction accepts (a string
        numerator only when denominator is omitted). Rational instances are
        also accepted for either argument and are unwrapped first.
        """
        if isinstance(numerator, Rational):
            numerator = numerator.constant
        if isinstance(denominator, Rational):
            denominator = denominator.constant
        self = object.__new__(cls)
        self._constant = Fraction(numerator, denominator)
        return self

    def __hash__(self):
        # Hash like the underlying Fraction, so Rational(n) and n hash alike
        # (they also compare equal through __eq__).
        return hash(self.constant)

    def coefficient(self, symbol):
        """
        Return 0: a constant has no symbol.

        Raises TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return 0

    def coefficients(self):
        """Yield nothing: a constant has no terms."""
        yield from ()

    @property
    def symbols(self):
        """Always the empty tuple."""
        return ()

    @property
    def dimension(self):
        """Always 0."""
        return 0

    def isconstant(self):
        return True

    def values(self):
        """Yield the constant value."""
        yield self._constant

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Rational) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """
        Parse a number such as '3/4' or '-2' with fractions.Fraction.

        Raises TypeError if string is not a str.
        """
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        return Rational(Fraction(string))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Rational or a Python rational number (sympy is
        imported lazily).

        Raises TypeError otherwise.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return Rational(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')