Implement Expression.fromstring
[linpy.git] / pypol / linear.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from fractions import Fraction, gcd
7
8 from pypol import isl
9 from pypol.isl import libisl
10
11
# Public names exported by a star-import of this module.
__all__ = [
    'Expression', 'Constant', 'Symbol', 'symbols',
    'eq', 'le', 'lt', 'ge', 'gt',
    'Polyhedron',
    'Empty', 'Universe'
]
18
19
def _polymorphic_method(func):
    """Let a binary Expression method accept plain rational numbers.

    The right-hand operand reaches *func* as an Expression: a
    ``numbers.Rational`` is promoted to a ``Constant`` first. For any other
    operand type the wrapper returns ``NotImplemented`` so Python can try the
    reflected operation instead.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """Let a module-level comparison function accept a rational left operand.

    A rational *a* is promoted to a ``Constant``; an Expression is passed
    through; anything else raises TypeError. Only the left operand is checked
    because the wrapped functions delegate to polymorphic Expression methods,
    which handle the right operand themselves.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
43
44
# Shared isl context passed to every isl allocation made by this module
# (see Polyhedron._toisl).
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names to rational coefficients and carries a
    rational constant term. Construction normalises degenerate cases: an
    expression without symbols is returned as a ``Constant``, and ``1*name``
    with no constant term is returned as a ``Symbol``.
    """

    def __new__(cls, coefficients=None, constant=0):
        """Create a linear expression.

        Args:
            coefficients: a string (parsed by ``fromstring``), a dict, an
                iterable of ``(symbol, coefficient)`` pairs, or None. Symbols
                are names or ``Symbol`` instances; coefficients are rationals
                or ``Constant`` instances. Zero coefficients are dropped.
            constant: the constant term, a rational or a ``Constant``.

        Raises:
            TypeError: for a string combined with a non-zero constant, or for
                symbols, coefficients or constants of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _fromast(cls, node):
        """Convert a Python AST node into an Expression.

        Supported nodes: names, numeric literals, unary ``+``/``-`` and binary
        ``+``, ``-``, ``*``, ``/``. Anything else raises SyntaxError.
        """
        if isinstance(node, ast.Expression):
            # Produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Module):
            # Kept for callers that still pass an exec-mode tree.
            if len(node.body) != 1:
                raise SyntaxError('invalid syntax')
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif type(node).__name__ in ('Num', 'Constant'):
            # Number literals are ast.Num before Python 3.8 and ast.Constant
            # afterwards (ast.Num is deprecated); matching on the node class
            # name supports both without touching the deprecated alias.
            if type(node).__name__ == 'Num':
                value = node.n
            else:
                value = node.value
            # bool is a Number subclass but not a numeric literal here.
            if isinstance(value, numbers.Number) and not isinstance(value, bool):
                return Constant(value)
        elif isinstance(node, ast.UnaryOp):
            operand = cls._fromast(node.operand)
            if isinstance(node.op, ast.USub):
                return -operand
            elif isinstance(node.op, ast.UAdd):
                return +operand
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """Parse a linear expression such as ``'2a + 3(b - 1) - c/2'``.

        Juxtaposition means multiplication: a number or a closing parenthesis
        directly followed by a name or an opening parenthesis gets an explicit
        ``*`` inserted before the string is parsed as a Python expression.

        Raises:
            SyntaxError: if the string is not a supported expression
                (including the empty string and multiple statements).
            ValueError: if the expression is not linear.
        """
        # \b stops digits that end part of an identifier (the "12" in "x12y")
        # from being split off as a coefficient.
        string = re.sub(r'(\b\d+|\))\s*([^\W\d_]\w*|\()',
                        lambda m: '{}*{}'.format(m.group(1), m.group(2)),
                        string)
        # mode must be passed by keyword: the second positional parameter of
        # ast.parse is the filename.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    @property
    def symbols(self):
        """Sorted tuple of the symbol names with a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """Return the coefficient of *symbol* (a name or Symbol), 0 if absent.

        Raises:
            TypeError: if *symbol* is neither a string nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield ``(symbol, coefficient)`` pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return False; ``Constant`` overrides this to return True."""
        return False

    def values(self):
        """Yield every coefficient in symbol order, then the constant term."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        """Return False; ``Symbol`` overrides this to return True."""
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # other has been promoted to an Expression by the decorator.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """Multiply by a constant.

        Raises:
            ValueError: if both operands are non-constant (non-linear result).
        """
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                            for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, self.constant * factor)
        if self.isconstant():
            # Scaling is commutative; let the non-constant operand do it.
            return other * self
        raise ValueError('non-linear expression: '
            '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """Divide by a constant, producing Fraction coefficients.

        Raises:
            ValueError: if the divisor is not constant (non-linear result).
            ZeroDivisionError: if the divisor is zero.
        """
        if other.isconstant():
            divisor = other.constant
            coefficients = {symbol: Fraction(coefficient, divisor)
                            for symbol, coefficient in self.coefficients()}
            constant = Fraction(self.constant, divisor)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
            '{} / {}'.format(self._parenstr(), other._parenstr()))

    @_polymorphic_method
    def __rtruediv__(self, other):
        # other has been promoted to a Constant by the decorator; division by
        # a non-constant self raises ValueError inside __truediv__.
        return other / self

    def __str__(self):
        """Render as e.g. ``2*a - b + 3``; the zero expression is ``0``."""
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesised unless it is a constant or symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """Scale by the lcm of all denominators so every value is an integer."""
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron defined by ``self == other``."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the Polyhedron defined by ``self <= other``."""
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the Polyhedron defined by ``self < other``.

        Once the difference is integral, ``> 0`` is ``- 1 >= 0`` over the
        integers.
        """
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the Polyhedron defined by ``self >= other``."""
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the Polyhedron defined by ``self > other`` (see __lt__)."""
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
328
329
class Constant(Expression):
    """A linear expression with no symbols, i.e. a single rational value."""

    def __new__(cls, numerator=0, denominator=None):
        """Create a constant.

        Args:
            numerator: a rational number or a Constant when *denominator* is
                None; otherwise the numerator passed to ``Fraction``.
            denominator: optional denominator; when given, the value is
                ``Fraction(numerator, denominator)``.

        Raises:
            TypeError: if a lone *numerator* is neither rational nor Constant.
        """
        # object.__new__ directly: Expression.__new__ would route back here
        # and recurse, and object().__new__ builds a throwaway instance.
        self = object.__new__(cls)
        if denominator is None:
            if isinstance(numerator, numbers.Rational):
                self._constant = numerator
            elif isinstance(numerator, Constant):
                self._constant = numerator.constant
            else:
                raise TypeError('constant must be a rational number or a Constant instance')
        else:
            self._constant = Fraction(numerator, denominator)
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return bool(self.constant)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._constant)
356
357
class Symbol(Expression):
    """A linear expression consisting of one symbol with coefficient 1."""

    def __new__(cls, name):
        """Create a symbol from a name or copy the name of another Symbol.

        Raises:
            TypeError: if *name* is neither a string nor a Symbol.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # A one-element tuple: tuple(name) would split the name into its
        # characters, e.g. 'xy' -> ('x', 'y').
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)
382
def symbols(names):
    """Return a generator of Symbol objects, one per name.

    *names* is either an iterable of names (or Symbol instances) or a single
    string in which names are separated by commas and/or whitespace.
    """
    if isinstance(names, str):
        # Commas count as whitespace separators.
        names = ' '.join(names.split(',')).split()
    return (Symbol(each) for each in names)
387
388
@_polymorphic_operator
def eq(a, b):
    """Return the Polyhedron defined by the constraint ``a == b``.

    Expression.__eq__ is plain structural equality returning a bool, so the
    constraint is built with Expression._eq instead.
    """
    return a._eq(b)

@_polymorphic_operator
def le(a, b):
    """Return the Polyhedron defined by the constraint ``a <= b``."""
    return a.__le__(b)

@_polymorphic_operator
def lt(a, b):
    """Return the Polyhedron defined by the constraint ``a < b``."""
    return a.__lt__(b)

@_polymorphic_operator
def ge(a, b):
    """Return the Polyhedron defined by the constraint ``a >= b``."""
    return a.__ge__(b)

@_polymorphic_operator
def gt(a, b):
    """Return the Polyhedron defined by the constraint ``a > b``."""
    return a.__gt__(b)
408
409
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is the conjunction of integer-valued constraints:
    ``expr == 0`` for each expression in ``equalities`` and ``expr >= 0`` for
    each expression in ``inequalities``. Set operations are delegated to isl
    through ``_toisl``.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """Create a polyhedron from constraint expressions.

        Raises:
            TypeError: if a constraint has a non-integer coefficient or
                constant, or if a string is combined with *inequalities*.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integerconstraints(equalities, '{} == 0')
        self._inequalities = cls._integerconstraints(inequalities, '{} >= 0')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _integerconstraints(constraints, template):
        """Return *constraints* as a tuple, rejecting non-integer ones.

        *template* formats the offending constraint in the TypeError message.
        """
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                        + template.format(constraint))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @property
    def equalities(self):
        """Tuple of expressions constrained to ``== 0``."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions constrained to ``>= 0``."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of every symbol appearing in a constraint."""
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        """Compare with another polyhedron in the union of both symbol sets.

        NOTE(review): isl "plain" checks are quick syntactic tests and may
        report two equal sets as different; confirm this is acceptable.
        """
        if not isinstance(other, Polyhedron):
            return NotImplemented
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """Return True if the two polyhedra have no element in common."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return True if every element of self is in *other*."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # isl_set_is_subset(set1, set2) tests set1 <= set2.
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if self is a strict subset of *other*."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        """Return True if every element of *other* is in self."""
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return True if self is a strict superset of *other*."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """Return the polyhedron of elements common to self and all *others*.

        A polyhedron is a conjunction of constraints, so the intersection
        simply gathers every constraint.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        """Elements of self that are not in *other*.

        NOTE(review): this returns the raw result of isl_set_subtract, not a
        Polyhedron, because _fromisl is not implemented yet.
        """
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Return the sorted list of symbols of self and all *others*."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """Convert this polyhedron into an isl basic set.

        Args:
            symbols: ordered symbol names giving the set dimensions; defaults
                to ``self.symbols``. It must contain every constraint symbol
                (``symbols.index`` raises ValueError otherwise).

        Returns:
            an ``isl.BasicSet`` wrapping the constructed basic set.
        """
        if symbols is None:
            symbols = self.symbols
        dimension = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        # Equalities and inequalities differ only in the isl allocator; one
        # shared loop keeps the two encodings from drifting apart.
        kinds = ((libisl.isl_equality_alloc, self.equalities),
                 (libisl.isl_inequality_alloc, self.inequalities))
        for allocate, constraints in kinds:
            for constraint in constraints:
                cst = allocate(libisl.isl_local_space_copy(ls))
                for symbol, coefficient in constraint.coefficients():
                    val = str(coefficient).encode()
                    val = libisl.isl_val_read_from_str(_main_ctx, val)
                    dim = symbols.index(symbol)
                    cst = libisl.isl_constraint_set_coefficient_val(
                        cst, libisl.isl_dim_set, dim, val)
                if constraint.constant != 0:
                    val = str(constraint.constant).encode()
                    val = libisl.isl_val_read_from_str(_main_ctx, val)
                    cst = libisl.isl_constraint_set_constant_val(cst, val)
                bset = libisl.isl_basic_set_add_constraint(bset, cst)
        # NOTE(review): ls is never released here; check the libisl wrapper
        # before adding an isl_local_space_free call.
        return isl.BasicSet(bset)

    @classmethod
    def _fromisl(cls, bset, symbols):
        '''Take a basic set in isl form and convert it back into a Polyhedron.

        Not implemented yet. isl example code gives the isl form as:
            "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}"
        while our printer gives:
            { [i0, i1] : 2i1 >= -2 - i0 }
        '''
        raise NotImplementedError
633
# The empty polyhedron: its only constraint, 0 == 1, is unsatisfiable. It is
# built with Expression._eq, which returns a Polyhedron, instead of
# Expression.__eq__, which returns a bool.
Empty = Constant(0)._eq(1)

# The universe: a polyhedron with no constraints.
Universe = Polyhedron()
636
if __name__ == '__main__':
    # 2a + 2b + 1 is always odd, so this equality has no integer solution.
    empty_demo = Polyhedron(equalities=[Expression('2a + 2b + 1')])
    # 3x + 2y + 3 == 0 has integer solutions, e.g. x = -1, y = 0.
    nonempty_demo = Polyhedron(equalities=[Expression('3x + 2y + 3')])
    for demo in (empty_demo, nonempty_demo):
        print(demo._toisl())