07d40052e2d240aefba921e84630d7e64a8c1dd0
[linpy.git] / pypol / linexprs.py
import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from fractions import Fraction
from math import gcd
8
9
# Names exported by ``from <module> import *``.
__all__ = ['Expression', 'Symbol', 'Dummy', 'symbols', 'Rational']
15
16
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is an Expression.

    An operand that already is an Expression is passed through unchanged, a
    plain rational number is first wrapped into a Rational, and for any other
    operand the wrapper returns NotImplemented without calling func.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            # promote plain numbers so that func only deals with expressions
            right = Rational(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    A linear expression is a sum of symbols weighted by rational coefficients,
    plus a rational constant. Coefficients are stored as Fraction instances in
    a mapping sorted by symbol sort key; null coefficients are never stored.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients may be a string (parsed with fromstring(), in which case
        constant must be left to 0), None, a mapping from symbols to rational
        numbers, or an iterable of (symbol, coefficient) pairs.

        The result is not always a plain Expression: an expression without any
        non-null coefficient is returned as a Rational, and an expression
        reduced to one symbol with coefficient 1 and a null constant is
        returned as that symbol.

        Raises TypeError if a symbol is not a Symbol instance, or if a
        coefficient or the constant is not a rational number.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Null terms are dropped *before* looking for degenerate cases, so
        # that results such as x - x or x + y - y collapse to a Rational or a
        # Symbol; otherwise they would compare equal to those objects while
        # having a different hash and type.
        coefficients = [(symbol, Fraction(coefficient))
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object().__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol as a Rational, 0 if the symbol does
        not appear in the expression.

        Raises TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return Rational(self._coefficients.get(symbol, 0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the (symbol, coefficient) pairs of the expression, the
        coefficients being Rational instances.
        """
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """
        The constant term of the expression, as a Rational.
        """
        return Rational(self._constant)

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The number of symbols present in the expression.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression is a constant; only Rational overrides
        this to return True.
        """
        return False

    def issymbol(self):
        """
        Return True if the expression is a symbol; only Symbol overrides this
        to return True.
        """
        return False

    def values(self):
        """
        Iterate over the coefficients of the expression, followed by its
        constant term, all as Rational instances.
        """
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the expression by a rational number; any other
        operand yields NotImplemented.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return Expression(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the expression by a rational number; any other
        operand yields NotImplemented.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return Expression(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        # returns a boolean, not a constraint
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    # The ordering operators build constraint objects from the sibling
    # polyhedra module; it is imported lazily inside each method.

    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that all of them
        become integers.
        """
        denominators = [value.denominator
            for value in self._coefficients.values()]
        denominators.append(self._constant.denominator)
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), denominators)
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute symbols by expressions and return the resulting expression.

        Called with two arguments, symbol is replaced by expression. Called
        with one argument, symbol must be a mapping or an iterable of
        (symbol, expression) pairs; the substitutions are then applied one
        after the other, each on the result of the previous one.

        Raises TypeError if a substituted symbol is not a Symbol instance.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            # rebuild the expression without the symbol, then add back the
            # replacement weighted by the coefficient the symbol had
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result._coefficients.items()
                if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python syntax tree into an expression.

        Names become symbols and numeric literals become rationals; unary
        plus and minus and the binary operators +, -, * and / are applied to
        the converted operands. Raises SyntaxError for any other node.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, (ast.Expression, ast.Expr)):
            # ast.Expression wraps trees parsed in 'eval' mode, ast.Expr wraps
            # expression statements parsed in 'exec' mode
            body = node.body if isinstance(node, ast.Expression) else node.value
            return cls._fromast(body)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, getattr(ast, 'Constant', ())):
            # ast.Constant also carries strings, booleans, None, etc.; only
            # numeric literals are accepted, as ast.Num used to do
            value = node.value
            if (isinstance(value, (int, float, complex))
                    and not isinstance(value, bool)):
                return Rational(value)
        elif type(node).__name__ == 'Num':
            # legacy node produced by old parsers; matched by name because
            # the ast.Num attribute is deprecated and absent from recent
            # Python versions
            return Rational(node.n)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string such as '2x + 3*y - 1' into an expression.

        Raises SyntaxError if the string is not a valid expression.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # the second positional parameter of ast.parse() is the file name, so
        # the parsing mode has to be given by keyword
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree.body)

    def __repr__(self):
        # Signs are tested on the stored Fraction values: the ordering
        # operators of Rational build constraints instead of booleans.
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self._constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        """
        Return the LaTeX representation of the expression, enclosed in '$$'.
        """
        def latex(value):
            # LaTeX code of a Fraction, without the enclosing '$$'
            return Rational(value)._repr_latex_().strip('$')
        # Signs are tested on the stored Fraction values: the ordering
        # operators of Rational build constraints instead of booleans.
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}'.format(latex(coefficient))
            elif coefficient > 0:
                string += ' + {}'.format(latex(coefficient))
            elif coefficient < 0:
                string += ' - {}'.format(latex(-coefficient))
            string += '{}'.format(symbol._repr_latex_().strip('$'))
        constant = self._constant
        if len(string) == 0:
            string += '{}'.format(latex(constant))
        elif constant > 0:
            string += ' + {}'.format(latex(constant))
        elif constant < 0:
            string += ' - {}'.format(latex(-constant))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, enclosed in parentheses
        unless it is a constant or a symbol and always is false.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression into an expression.

        Raises ValueError if a term of expr is neither the constant term nor
        a sympy symbol.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """
        Convert the expression into a sympy expression.
        """
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr
314
315
class Symbol(Expression):
    """
    This class implements symbols, i.e. expressions made of a single named
    variable with coefficient 1 and a null constant.
    """

    def __new__(cls, name):
        """
        Create a symbol from its name, stripped of surrounding whitespace.

        Raises TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object().__new__(cls)
        # the name must be set first: the symbol is hashed right below, when
        # it is used as a key of its own coefficients
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return the tuple used to sort, hash and compare symbols.
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        # Only symbols have a sort key; for any other operand, let Python
        # fall back on the reflected comparison instead of failing with an
        # AttributeError.
        if not isinstance(other, Symbol):
            return NotImplemented
        return self.sortkey() == other.sortkey()

    def asdummy(self):
        """
        Return a new Dummy symbol bearing the same name.
        """
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python syntax tree reduced to a single name into a symbol.

        Raises SyntaxError for any other tree.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy symbol into a Symbol, or a sympy dummy into a Dummy.

        Raises TypeError if expr is not a sympy.Symbol instance.
        """
        import sympy
        if isinstance(expr, sympy.Dummy):
            return Dummy(expr.name)
        elif isinstance(expr, sympy.Symbol):
            return Symbol(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
373
374
class Dummy(Symbol):
    """
    This class implements dummy symbols. Each instance records the value of
    a class-wide counter at creation time, which takes part in its sort key.
    """

    # number of dummy symbols created so far
    _count = 0

    def __new__(cls, name=None):
        """
        Create a dummy symbol; without a name, one is generated from the
        current value of the counter.

        Raises TypeError if name is neither None nor a string.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # index and name must be set before the instance gets hashed as a
        # key of its own coefficients
        self._index = index
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count = index + 1
        return self

    def __hash__(self):
        key = self.sortkey()
        return hash(key)

    def sortkey(self):
        """
        Return the (name, index) tuple used to sort, hash and compare dummy
        symbols.
        """
        return (self._name, self._index)

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)
405
406
def symbols(names):
    """
    Return a tuple of Symbol instances built from names, which is either an
    iterable of names or a single string of names separated by commas and/or
    whitespace.
    """
    if isinstance(names, str):
        # commas are turned into spaces so that split() handles both
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
411
412
class Rational(Expression, Fraction):
    """
    This class implements rational constants: expressions without symbols,
    which are also Fraction instances.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a rational number from the same arguments as Fraction.
        """
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        # attributes of the Expression side
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # attributes of the Fraction side
        self._numerator = value.numerator
        self._denominator = value.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """
        The constant term of the expression, i.e. the rational itself.
        """
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        if self.denominator != 1:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)
        return '{!r}'.format(self.numerator)

    def _repr_latex_(self):
        """
        Return the LaTeX representation of the rational, enclosed in '$$'.
        """
        numerator, denominator = self.numerator, self.denominator
        if denominator == 1:
            return '$${}$$'.format(numerator)
        if numerator < 0:
            # the minus sign is written in front of the fraction
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-numerator,
                denominator)
        return '$$\\frac{{{}}}{{{}}}$$'.format(numerator,
            denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy rational, or any rational number, into a Rational.

        Raises TypeError for any other object.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')