Add _repr_latex_ methods for IPython pretty-printing
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from collections import OrderedDict, defaultdict, Mapping
7 from fractions import Fraction, gcd
8
9
# Public names exported by ``from linexprs import *``.
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
15
16
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is an Expression.

    A right operand that is already an Expression is passed through, a
    rational number is promoted to a Rational, and any other value makes the
    decorated method return NotImplemented without calling func.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression maps Symbol instances to nonzero rational coefficients and
    has a rational constant term. The constructor normalizes its result: an
    expression without symbols is returned as a Rational, and a single symbol
    with coefficient 1 and no constant term is returned as that Symbol.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients is either a string to parse (constant must then be 0),
        a mapping from symbols to coefficients, or an iterable of
        (symbol, coefficient) pairs; constant is a rational number. Zero
        coefficients are dropped.

        Raise TypeError if a symbol is not a Symbol instance, or if a
        coefficient or the constant is not a rational number.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        # The pairs are traversed twice (validation, then conversion), so an
        # iterator must be materialized or the second pass would see nothing.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        coefficients = [(symbol, Fraction(coefficient))
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        constant = Fraction(constant)
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        # Store the symbols in a canonical order: __hash__ and __repr__
        # depend on the iteration order of _coefficients.
        self._coefficients = OrderedDict(
            sorted(coefficients, key=lambda item: item[0].sortkey()))
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol as a Rational, or Rational(0) if the
        symbol does not occur. Raise TypeError if symbol is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        try:
            return Rational(self._coefficients[symbol])
        except KeyError:
            return Rational(0)

    __getitem__ = coefficient

    def coefficients(self):
        """Iterate over (symbol, coefficient) pairs, with Rational coefficients."""
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """The constant term, as a Rational."""
        return Rational(self._constant)

    @property
    def symbols(self):
        """Tuple of the symbols that have a nonzero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols that have a nonzero coefficient."""
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """Return True if the expression is a constant (see Rational)."""
        return False

    def issymbol(self):
        """Return True if the expression is a single symbol (see Symbol)."""
        return False

    def values(self):
        """Iterate over the coefficients, then the constant term, as Rationals."""
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        # Constant expressions are Rational instances (which override this),
        # so an Expression always involves a symbol and is never zero.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        # Only scaling by a constant keeps the result linear.
        if isinstance(other, Rational):
            return other.__rmul__(self)
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        # Only division by a constant keeps the result linear.
        if isinstance(other, Rational):
            return other.__rtruediv__(self)
        return NotImplemented

    def __rtruediv__(self, other):
        # other / self is not linear when self involves symbols, so refuse it
        # (Rational overrides this method for constant divisors). Aliasing
        # this to __truediv__ would silently compute self / other instead.
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self._constant == other._constant

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant term, so that all of
        them become integers.
        """
        # fractions.gcd was removed in Python 3.9; math.gcd replaces it.
        from math import gcd
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions for symbols and return the result.

        Call either subs(symbol, expression), or subs(substitutions) where
        substitutions is a mapping or an iterable of (symbol, expression)
        pairs. Substitutions are applied one after the other, in iteration
        order. Raise TypeError if a substituted key is not a Symbol.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            coefficients = [(othersymbol, coefficient)
                            for othersymbol, coefficient
                            in result._coefficients.items()
                            if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @staticmethod
    def _astnumber(node):
        """
        Return the value of node if it is a numeric literal, else None.

        Python 3.8+ parses literals to ast.Constant (ast.Num is deprecated and
        removed in Python 3.14); older versions produce ast.Num nodes, which
        are matched by class name so that ast.Num is never accessed on
        versions where it is deprecated or gone.
        """
        if isinstance(node, getattr(ast, 'Constant', ())):
            value = node.value
            # bool is a numbers.Number, but True/False are not numbers here.
            if isinstance(value, numbers.Number) and \
                    not isinstance(value, bool):
                return value
        elif type(node).__name__ == 'Num':
            return node.n
        return None

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Supported nodes are names, numeric literals, unary + and -, and the
        binary operators +, -, * and /; anything else raises SyntaxError.
        Non-linear operations (e.g. a product of two symbols) raise TypeError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        number = cls._astnumber(node)
        if number is not None:
            return Rational(number)
        if isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return -cls._fromast(node.operand)
            if isinstance(node.op, ast.UAdd):
                return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or a closing parenthesis directly followed (spaces allowed) by
    # an identifier or an opening parenthesis, e.g. '5x', '2 (x', ')y', where
    # a '*' is implied. The lookbehind keeps digits inside identifiers such
    # as 'x2y' or 'x_2y' from being treated as numbers.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse string into an expression.

        Raise TypeError if string is not a str, SyntaxError for unsupported
        syntax, and TypeError for non-linear operations such as 'x*y'.
        """
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        # Leading whitespace would make ast.parse raise IndentationError.
        string = string.strip()
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # Default 'exec' mode: the tree is an ast.Module unwrapped by _fromast.
        tree = ast.parse(string)
        return cls._fromast(tree)

    def _formatterms(self, number, symbol, times):
        """
        Render the expression as text.

        number and symbol are callables rendering a Rational and a Symbol;
        times is inserted between a coefficient other than +/-1 and its
        symbol. Signs are tested on the integer numerators because the
        ordering operators of Rational, inherited from Expression, build
        polyhedra instead of returning booleans.
        """
        string = ''
        for i, (sym, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += number(coefficient) + times
            elif coefficient.numerator > 0:
                string += ' + ' + number(coefficient) + times
            else:
                string += ' - ' + number(-coefficient) + times
            string += symbol(sym)
        constant = self.constant
        if not string:
            string = number(constant)
        elif constant.numerator > 0:
            string += ' + ' + number(constant)
        elif constant.numerator < 0:
            string += ' - ' + number(-constant)
        return string

    def __repr__(self):
        return self._formatterms(str, str, '*')

    def _repr_latex_(self):
        """Return a LaTeX rendering of the expression (IPython display hook)."""
        def latex(obj):
            return obj._repr_latex_().strip('$')
        return '${}$'.format(self._formatterms(latex, latex, ''))

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless always is false and
        the expression is a constant or a single symbol.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a linear sympy expression. Raise ValueError if a term of expr
        is neither a constant nor a rational multiple of a symbol.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """Convert the expression to an equivalent sympy expression."""
        import sympy
        # Start from a sympy zero and convert every number explicitly:
        # starting from the int 0 made a constant expression come back as a
        # Rational of this module (0 + Rational dispatches to Rational).
        expr = sympy.Integer(0)
        for symbol, coefficient in self.coefficients():
            term = sympy.Rational(coefficient.numerator,
                                  coefficient.denominator)
            expr += term * sympy.Symbol(symbol.name)
        constant = self.constant
        expr += sympy.Rational(constant.numerator, constant.denominator)
        return expr
318
319
class Symbol(Expression):
    """
    A named variable of a linear expression.

    Two symbols are equal when their names are equal; a Symbol is never
    equal to a Dummy.
    """

    def __new__(cls, name):
        """
        Create the symbol called name, with surrounding whitespace stripped.
        Raise TypeError if name is not a string.
        """
        if isinstance(name, str):
            self = object.__new__(cls)
            self._name = name.strip()
            # Inserting self as a key hashes it, and the hash is computed
            # from sortkey(), so the name has to be set first.
            self._coefficients = {self: 1}
            self._constant = 0
            self._symbols = (self,)
            self._dimension = 1
            return self
        raise TypeError('name must be a string')

    @property
    def name(self):
        """The name of the symbol."""
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Key used to order symbols inside an expression."""
        return (self.name,)

    def issymbol(self):
        return True

    def __eq__(self, other):
        if isinstance(other, Dummy) or not isinstance(other, Symbol):
            return False
        return self.name == other.name

    def asdummy(self):
        """Return a new Dummy carrying the same name."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """Build a Symbol from an AST that holds a single name, else raise SyntaxError."""
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        if isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        if isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        """Return the LaTeX rendering of the symbol (IPython display hook)."""
        return '${}$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol; raise TypeError for anything else."""
        import sympy
        if not isinstance(expr, sympy.Symbol):
            raise TypeError('expr must be a sympy.Symbol instance')
        return cls(expr.name)
376
377
class Dummy(Symbol):
    """
    A symbol that is only equal to itself.

    Every Dummy receives a unique index from a class-wide counter, so two
    dummies with the same name are distinct, and a Dummy never equals a
    Symbol.
    """

    # Number of Dummy instances created so far; the next index to hand out.
    _count = 0

    def __new__(cls, name=None):
        """
        Create a new dummy called name, or 'Dummy_<index>' when name is None.
        Raise TypeError if name is neither None nor a string.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            # Reject non-strings explicitly, as Symbol.__new__ does, rather
            # than failing with an AttributeError on name.strip().
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._index = Dummy._count
        self._name = name.strip()
        # The hash is computed from _name and _index, so both must be set
        # before self is used as a dictionary key.
        self._coefficients = {self: 1}
        self._constant = 0
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Key used to order symbols inside an expression."""
        return self._name, self._index

    def __eq__(self, other):
        return isinstance(other, Dummy) and self._index == other._index

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        """Return the name subscripted with the dummy index, in LaTeX."""
        return '${}_{{{}}}$'.format(self.name, self._index)
409
410
def symbols(names):
    """
    Return a tuple of Symbol instances.

    names is either a string of names separated by commas and/or
    whitespace, or an iterable of name strings.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
415
416
class Rational(Expression, Fraction):
    """
    A rational number that is also a constant linear expression: it has no
    symbols, and its constant term is the number itself.
    """

    def __new__(cls, numerator=0, denominator=None):
        self = Fraction.__new__(cls, numerator, denominator)
        self._symbols = ()
        self._dimension = 0
        self._coefficients = {}
        self._constant = Fraction(self)
        return self

    def __hash__(self):
        # Hash as a Fraction of the same value would.
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """The constant term, which is the number itself."""
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    @_polymorphic
    def __mul__(self, other):
        # Scale every coefficient and the constant of other by this number.
        factor = self._constant
        scaled = {symbol: coefficient * factor
                  for symbol, coefficient in other._coefficients.items()}
        return Expression(scaled, other._constant * factor)

    __rmul__ = __mul__

    @_polymorphic
    def __rtruediv__(self, other):
        # Compute other / self by dividing each term of other by this number.
        divisor = self._constant
        quotients = {symbol: coefficient / divisor
                     for symbol, coefficient in other._coefficients.items()}
        return Expression(quotients, other._constant / divisor)

    @classmethod
    def fromstring(cls, string):
        """Parse a number such as '3' or '3/4'; raise TypeError for non-strings."""
        if isinstance(string, str):
            return Rational(string)
        raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.denominator != 1:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)
        return '{!r}'.format(self.numerator)

    def _repr_latex_(self):
        """Return the LaTeX rendering of the number (IPython display hook)."""
        numerator, denominator = self.numerator, self.denominator
        if denominator == 1:
            return '${}$'.format(numerator)
        sign = '-' if numerator < 0 else ''
        return '${}\\frac{{{}}}{{{}}}$'.format(sign, abs(numerator),
                                               denominator)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Rational (or a Python rational number)."""
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')