Improve Expression.fromstring(), Domain.fromstring()
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from fractions import Fraction, gcd
7
8
# Public API re-exported by ``from linexprs import *``.
__all__ = ['Expression', 'Symbol', 'symbols', 'Constant']
14
15
def _polymorphic(func):
    """
    Let a binary Expression method also accept rational numbers on the right.

    An Expression right operand is passed through untouched; a
    numbers.Rational is wrapped in a Constant first.  Any other operand
    makes the wrapper return NotImplemented, so Python can fall back to the
    reflected operation of the other operand.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return wrapper
26
27
class Expression:
    """
    This class implements linear expressions.

    An Expression is a sum of symbols scaled by rational coefficients plus a
    rational constant term, e.g. ``2*x - y/3 + 1``.  Arithmetic operators
    return new objects and do not modify their operands; products and
    quotients that would make the result non-linear raise ValueError.
    Constructing an Expression may return the more specific Constant or
    Symbol subclass when the value is a plain number or a single symbol.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a str (parsed with fromstring()), a dict, or an
            iterable of (symbol, coefficient) pairs.  Symbols are str or
            Symbol instances; coefficients are rational numbers or Constant
            instances.  Zero coefficients are dropped.
        constant -- the constant term, a rational number or a Constant.

        Returns a Constant when no non-zero coefficient remains, and a
        Symbol when exactly one symbol with coefficient 1 and no constant
        remains.  Raises TypeError for arguments of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a str or Symbol), or 0 if the
        symbol does not occur in the expression.

        Raises TypeError if symbol is neither a str nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol name, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    @property
    def symbols(self):
        """Sorted tuple of the names of the symbols in the expression."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def isconstant(self):
        """Return True if the expression is a Constant."""
        return False

    def issymbol(self):
        """Return True if the expression is a Symbol."""
        return False

    def values(self):
        """Yield every coefficient in symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # Computes ``other - self``; _polymorphic has wrapped a plain number
        # in a Constant, and unsupported operands yield NotImplemented.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                for symbol, coefficient in self.coefficients()}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if self.isconstant():
            # Let Python retry with other.__rmul__(self), which takes the
            # constant branch above.
            return NotImplemented
        raise ValueError('non-linear expression: '
            '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        if other.isconstant():
            coefficients = {symbol: Fraction(coefficient, other.constant)
                for symbol, coefficient in self.coefficients()}
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
            '{} / {}'.format(self._parenstr(), other._parenstr()))

    @_polymorphic
    def __rtruediv__(self, other):
        # Computes ``other / self`` when the left operand is a plain number;
        # _polymorphic has already wrapped it in a Constant.
        if self.isconstant():
            return other / self
        raise ValueError('non-linear expression: '
            '{} / {}'.format(other._parenstr(), self._parenstr()))

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality: compares the expressions structurally instead
        # of building an equality constraint
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return self._coefficients == other._coefficients and \
            self.constant == other.constant

    # The comparison operators build constraint objects from the polyhedra
    # module rather than returning booleans.  The imports are deferred,
    # presumably to avoid a circular import -- confirm against polyhedra.

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def __hash__(self):
        # Hashes the same fields __eq__ compares.
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Return the expression scaled by the least common multiple of all
        denominators, so every coefficient and the constant are integers.
        """
        try:
            from math import gcd
        except ImportError:  # Python < 3.5 has no math.gcd
            from fractions import gcd
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @staticmethod
    def _astnumber(node):
        """Return the value of an int or float literal node, else None."""
        # Python >= 3.8 parses every literal to ast.Constant; older versions
        # produce ast.Num, which is deprecated and removed in Python 3.14.
        if isinstance(node, getattr(ast, 'Constant', ())):
            value = node.value
        elif 'Num' in vars(ast) and isinstance(node, ast.Num):
            value = node.n
        else:
            return None
        # bool is an int subclass, but True/False are not numeric literals
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        return None

    @classmethod
    def _fromast(cls, node):
        """
        Build an Expression from a Python AST node.

        Accepts Module/Expression/Expr wrappers around a single expression,
        names, int and float literals, unary minus and plus, and the binary
        operators + - * /.  Raises SyntaxError for any other construct and
        ValueError (from the operators) for non-linear terms.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        else:
            value = cls._astnumber(node)
            if value is not None:
                return Constant(value)
        raise SyntaxError('invalid syntax')

    # A number or ')' followed by an identifier or '(' -- the places where
    # an implicit multiplication is inserted ('5x', '2 (', ')y', ')(').
    # The look-behind keeps digits inside identifiers ('x2y') intact.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression such as '2x - 3y + 1/2'.

        Implicit multiplication by a number is supported ('5x', '2(x + 1)').
        Raises TypeError if string is not a str, SyntaxError if it is not a
        supported expression, and ValueError if it is not linear.
        """
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string.strip())
        # mode must be passed by keyword: the second positional parameter
        # of ast.parse() is the filename, not the mode
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree.body)

    def __str__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += '-{}'.format(symbol) if i == 0 else \
                    ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                # print the sign as a binary operator: ' - 2*x'
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0 and not self.symbols:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a lone constant or
        symbol (or always if always is true)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a linear sympy expression.

        Raises ValueError if a term is neither a constant nor a rational
        multiple of a single sympy.Symbol.  Assumes every coefficient is a
        sympy Rational (exposes .p and .q) -- TODO confirm for floats.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Return the equivalent sympy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr
335
336
class Symbol(Expression):
    """
    A single named variable, i.e. the linear expression ``1*name``.
    """

    # Expression already declares the shared slots; redeclaring them in a
    # subclass is wasteful and left undefined by the data model.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create a symbol from a str (surrounding whitespace is stripped) or
        from another Symbol.

        Raises TypeError for any other argument type.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        name = name.strip()
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # one-element tuple: tuple(name) would split 'foo' into 'f', 'o', 'o'
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The name of the symbol."""
        return self._name

    def issymbol(self):
        return True

    @classmethod
    def _fromast(cls, node):
        """
        Build a Symbol from an AST node holding a bare name, possibly
        wrapped in a Module, Expression or Expr node.

        Raises SyntaxError for any other construct.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol; raises TypeError for anything else."""
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
384
385
def symbols(names):
    """
    Return a generator producing one Symbol per name.

    names -- a str of names separated by commas and/or whitespace
        (e.g. 'x, y z'), or any iterable of values accepted by Symbol().
    """
    tokens = ' '.join(names.split(',')).split() \
        if isinstance(names, str) else names
    return (Symbol(token) for token in tokens)
390
391
class Constant(Expression):
    """
    A linear expression with no symbols: a rational constant.
    """

    # All attributes live in Expression's slots; an empty __slots__ keeps
    # Constant instances from growing a per-instance __dict__.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a constant.

        numerator -- a Constant (copied when denominator is omitted), or
            anything fractions.Fraction accepts: a rational number, a float,
            or a numeric string such as '3/4'.
        denominator -- optional denominator passed on to Fraction.
        """
        self = object.__new__(cls)
        if denominator is None and isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            self._constant = Fraction(numerator, denominator)
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """
        Parse a numeric string with fractions.Fraction, e.g. '-3/4'.

        Raises TypeError if string is not a str.
        """
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        # Integers print as Constant(n), other values as Constant(p, q).
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Rational (or a plain numbers.Rational).

        Raises TypeError for anything else.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')