1 import ast
2 import functools
3 import numbers
4 import re
6 from fractions import Fraction, gcd
8 from . import isl
9 from .isl import libisl
# Public API of the module, grouped by role.
__all__ = [
    # linear expressions
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    # constraint builders
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    # polyhedra
    'Polyhedron',
    'Empty',
    'Universe',
]
20 def _polymorphic_method(func):
21 @functools.wraps(func)
22 def wrapper(a, b):
23 if isinstance(b, Expression):
24 return func(a, b)
25 if isinstance(b, numbers.Rational):
26 b = Constant(b)
27 return func(a, b)
28 return NotImplemented
29 return wrapper
31 def _polymorphic_operator(func):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func)
35 def wrapper(a, b):
36 if isinstance(a, numbers.Rational):
37 a = Constant(a)
38 return func(a, b)
39 elif isinstance(a, Expression):
40 return func(a, b)
41 raise TypeError('arguments must be linear expressions')
42 return wrapper
# Module-wide isl context, created once at import time. It is passed to the
# libisl calls in Polyhedron._toisl() that need a context argument.
_main_ctx = isl.Context()
class Expression:
    """
    This class implements linear expressions.

    A linear expression is a sum of symbols, each multiplied by a rational
    coefficient, plus a rational constant. All attributes are set in
    __new__; no method of this class modifies them afterwards.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Build a linear expression.

        coefficients may be a string (parsed with fromstring(), in which
        case constant must be left to 0), a dict mapping symbols to
        coefficients, an iterable of (symbol, coefficient) pairs, or None.
        Symbols are strings or Symbol instances; coefficients and constant
        are rational numbers or Constant instances.

        The result is a Constant when no non-zero coefficient is left, and a
        Symbol when the expression is a single symbol with coefficient 1 and
        no constant.

        Raises TypeError on arguments of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        # Null coefficients are dropped so that equal expressions always
        # have the same internal representation.
        coefficients = [(symbol, coefficient)
                for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                        'or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                    'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into a linear expression.

        Supported nodes are names, numeric literals, unary minus and the
        binary operators +, -, * and /, optionally wrapped in a Module (with
        a single statement), Expr or Expression node. Anything else raises
        SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # root node produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Constant):
            # ast.Constant replaces the deprecated ast.Num. Only numeric
            # literals are meaningful here; Constant() itself rejects
            # non-rational numbers with a TypeError.
            value = node.value
            if isinstance(value, numbers.Number) \
                    and not isinstance(value, bool):
                return Constant(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string such as '2x + 3(y - 1)' into a linear expression.

        Raises SyntaxError when the string is not a supported expression.
        """
        # Make implicit multiplications explicit: a number or a closing
        # parenthesis directly followed by an identifier or an opening
        # parenthesis gets a '*' inserted in between.
        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        # mode must be passed by keyword: the second positional argument of
        # ast.parse() is the file name.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    @property
    def symbols(self):
        """Sorted tuple of the symbol names of the expression."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols of the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a string or a Symbol instance), or
        0 when the symbol does not appear in the expression.

        Raises TypeError when symbol is of another type.
        """
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in the order of self.symbols."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """Constant term of the expression."""
        return self._constant

    def isconstant(self):
        """Return False; overridden by Constant."""
        return False

    def values(self):
        """Yield every coefficient (in symbol order), then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        """Return False; overridden by Symbol."""
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return the sum of two linear expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    # addition is commutative, so "1 + x" reuses the same code
    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return the difference of two linear expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Return the product of the expression by a constant.

        Raises ValueError when neither operand is constant, since the result
        would not be linear. Returns NotImplemented when only self is
        constant, so that Python retries with the reflected operation.
        """
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                    for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, self.constant * factor)
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Return the quotient of the expression by a constant.

        Raises ValueError when the divisor is not constant.
        """
        if other.isconstant():
            divisor = other.constant
            coefficients = {symbol: Fraction(coefficient, divisor)
                    for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, Fraction(self.constant, divisor))
        raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))

    def __rtruediv__(self, other):
        """
        Return other / self, where other is a rational number.

        Raises ValueError when self is not constant.
        """
        # Python only falls back on this method when the left operand is not
        # an Expression (Expression / Expression is fully handled by
        # __truediv__), hence the test against rational numbers.
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(Constant(other)._parenstr(),
                        self._parenstr()))
        return NotImplemented

    def __str__(self):
        string = ''
        for i, symbol in enumerate(self.symbols):
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += ('-{}' if i == 0 else ' - {}').format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0 and not self.symbols:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return str(self) wrapped in parentheses.

        Unless always is true, constants and symbols are left unwrapped.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__,
                items, self.constant)

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())),
                self._constant))

    def _toint(self):
        """
        Return the expression scaled by the least common multiple of the
        denominators of its coefficients and constant, so that every value
        of the result has denominator 1.
        """
        # fractions.gcd was removed in Python 3.9; math.gcd replaces it.
        from math import gcd
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron defined by self == other."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the polyhedron defined by self <= other."""
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the polyhedron defined by self < other."""
        # the scaled difference is shifted by 1 to turn the strict
        # inequality into a non-strict one
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the polyhedron defined by self >= other."""
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the polyhedron defined by self > other."""
        return Polyhedron(inequalities=[(self - other)._toint() - 1])

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression into a linear expression.

        Raises ValueError when a term of expr is neither a constant nor a
        plain sympy symbol.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                coefficients[symbol.name] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Return the expression as a sympy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            expr += coefficient * sympy.Symbol(symbol)
        expr += self.constant
        return expr
class Constant(Expression):
    """
    Linear expression without symbols, i.e. a plain rational number.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Build a constant from a rational number or another Constant, or from
        a numerator and a denominator.

        Raises TypeError when a single argument is given that is neither a
        rational number nor a Constant instance.
        """
        self = object.__new__(cls)
        if denominator is not None:
            self._constant = Fraction(numerator, denominator)
        elif isinstance(numerator, numbers.Rational):
            self._constant = numerator
        elif isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            raise TypeError('constant must be a rational number or a Constant instance')
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        """Return True: a Constant is always constant."""
        return True

    def __bool__(self):
        # a constant is falsy exactly when its value is zero
        return bool(self.constant)

    def __repr__(self):
        value = self.constant
        name = self.__class__.__name__
        if value.denominator == 1:
            return '{}({!r})'.format(name, value)
        return '{}({!r}, {!r})'.format(name,
                value.numerator, value.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Rational (or any rational number) into a Constant.

        Raises TypeError for any other kind of object.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return cls(expr)
        raise TypeError('expr must be a sympy.Rational instance')
class Symbol(Expression):
    """
    Linear expression made of a single symbol with coefficient 1 and no
    constant.
    """

    # Only the slot added by this class is declared: the slots of
    # Expression are inherited, and declaring them again would allocate
    # them a second time in every instance.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Build a symbol from its name (a string) or from another Symbol.

        Raises TypeError when name is of another type.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # One-element tuple: tuple(name) would split a multi-character name
        # into its individual characters.
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """Name of the symbol."""
        return self._name

    def issymbol(self):
        """Return True: a Symbol is always a symbol."""
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Symbol into a Symbol.

        Raises TypeError for any other kind of object.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        raise TypeError('expr must be a sympy.Symbol instance')
def symbols(names):
    """
    Return a generator of Symbol instances, one per name.

    names is either an iterable of names or a single string in which the
    names are separated by commas and/or whitespace.
    """
    tokens = names.replace(',', ' ').split() if isinstance(names, str) \
            else names
    return (Symbol(token) for token in tokens)
@_polymorphic_operator
def eq(a, b):
    """
    Return the polyhedron defined by the constraint a == b.

    Expression.__eq__ is the structural equality test and returns a bool, so
    the constraint has to be built with Expression._eq instead.
    """
    return a._eq(b)
@_polymorphic_operator
def le(a, b):
    """Return the result of a.__le__(b), i.e. the constraint a <= b."""
    constraint = a.__le__(b)
    return constraint
@_polymorphic_operator
def lt(a, b):
    """Return the result of a.__lt__(b), i.e. the constraint a < b."""
    constraint = a.__lt__(b)
    return constraint
@_polymorphic_operator
def ge(a, b):
    """Return the result of a.__ge__(b), i.e. the constraint a >= b."""
    constraint = a.__ge__(b)
    return constraint
@_polymorphic_operator
def gt(a, b):
    """Return the result of a.__gt__(b), i.e. the constraint a > b."""
    constraint = a.__gt__(b)
    return constraint
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as two tuples of linear expressions whose
    coefficients and constants are all integers: every expression of
    equalities is constrained to be == 0, every expression of inequalities
    to be >= 0.
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron from iterables of linear expressions, or from a
        single string parsed with fromstring().

        Raises TypeError when a string is combined with inequalities, or
        when a constraint holds a non-integer value.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checkintegers(equalities, '==')
        self._inequalities = cls._checkintegers(inequalities, '>=')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _checkintegers(constraints, relation):
        """
        Return constraints as a tuple (empty for None), raising TypeError on
        the first one holding a value whose denominator is not 1.
        """
        checked = []
        if constraints is not None:
            for constraint in constraints:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                                '{} {} 0'.format(constraint, relation))
                checked.append(constraint)
        return tuple(checked)

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into an (equalities, inequalities) pair of
        lists of linear expressions.

        Supported nodes are (possibly chained) comparisons using <, <=, ==,
        >= and >, combined with the & operator. Anything else raises
        SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # root node produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd):
            equalities1, inequalities1 = cls._fromast(node.left)
            equalities2, inequalities2 = cls._fromast(node.right)
            equalities = equalities1 + equalities2
            inequalities = inequalities1 + inequalities2
            return equalities, inequalities
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    # unsupported comparison operator
                    break
                # in a chain a < b < c, b is the left operand of the next
                # comparison
                left = right
            else:
                return equalities, inequalities
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string such as '{x >= 0, x + 2y = 4}' into a polyhedron.

        Constraints are separated by ',', ';', 'and', '&&', '/\\' or the
        logical-and sign; surrounding braces are optional and a single '='
        stands for '=='.

        Raises SyntaxError when the string cannot be parsed.
        """
        string = string.strip()
        string = re.sub(r'^\{\s*|\s*\}$', '', string)
        # The string is split before implicit multiplications are expanded,
        # otherwise '0 and y' would be rewritten into '0*and y'. 'and' is
        # matched as a whole word so that symbols containing these letters
        # are left alone.
        tokens = re.split(r',|;|\band\b|&&|/\\|\u2227', string, flags=re.I)
        constraints = []
        for token in tokens:
            token = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', token)
            token = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', token)
            constraints.append('({})'.format(token))
        # mode must be passed by keyword: the second positional argument of
        # ast.parse() is the file name.
        tree = ast.parse(' & '.join(constraints), mode='eval')
        equalities, inequalities = cls._fromast(tree)
        return cls(equalities, inequalities)

    @property
    def equalities(self):
        """Tuple of the expressions constrained to be == 0."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of the expressions constrained to be >= 0."""
        return self._inequalities

    @property
    def constraints(self):
        """Tuple of all constraints: equalities, then inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of the symbol names used by the constraints."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols of the polyhedron."""
        return len(self.symbols)

    def __bool__(self):
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        # works correctly when symbols is not passed
        # should be equal if values are the same even if symbols are different
        if not isinstance(other, Polyhedron):
            return NotImplemented
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        """Return True if the polyhedron has no element."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        """Return True if the polyhedron is the whole space."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """
        Return True if the polyhedron has no element in common with other.
        """
        # NOTE(review): unlike issubset(), both sets are built over their
        # own symbols rather than over the union of the symbols -- confirm
        # this is intended.
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return True if every element of the polyhedron is in other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # NOTE(review): _toisl() wraps an isl basic set while the isl_set_*
        # functions are declared on isl sets -- confirm the binding accepts
        # this.
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        # self < other: self is a strict subset of other
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        """Return True if every element of other is in the polyhedron."""
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        # self > other: other is a strict subset of self
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and
        # all others
        # TODO: a poor man's implementation could concatenate the equalities
        # and the inequalities of self and of every other polyhedron, and
        # build a new instance of self.__class__ from the two lists.
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        """
        Subtract other from the polyhedron.

        NOTE(review): the value returned is whatever
        libisl.isl_set_subtract() returns, not a Polyhedron instance --
        confirm what callers expect.
        """
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return libisl.isl_set_subtract(bset, other)

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(constraint)
                for constraint in self.equalities]
        constraints.extend('{} >= 0'.format(constraint)
                for constraint in self.inequalities)
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                    ''.format(self.__class__.__name__,
                        equalities, inequalities)

    @classmethod
    def _fromsympy(cls, expr):
        """
        Convert a sympy relational (or a conjunction of relationals) into an
        (equalities, inequalities) pair of lists of linear expressions.

        Raises ValueError when expr is of any other kind.
        """
        import sympy
        equalities = []
        inequalities = []
        if expr.func == sympy.And:
            for arg in expr.args:
                arg_eqs, arg_ins = cls._fromsympy(arg)
                equalities.extend(arg_eqs)
                inequalities.extend(arg_ins)
        elif expr.func == sympy.Eq:
            expr = Expression.fromsympy(expr.args[0] - expr.args[1])
            equalities.append(expr)
        else:
            if expr.func == sympy.Lt:
                expr = Expression.fromsympy(expr.args[1] - expr.args[0] - 1)
            elif expr.func == sympy.Le:
                expr = Expression.fromsympy(expr.args[1] - expr.args[0])
            elif expr.func == sympy.Ge:
                expr = Expression.fromsympy(expr.args[0] - expr.args[1])
            elif expr.func == sympy.Gt:
                expr = Expression.fromsympy(expr.args[0] - expr.args[1] - 1)
            else:
                raise ValueError('non-polyhedral expression: '
                        '{!r}'.format(expr))
            inequalities.append(expr)
        return equalities, inequalities

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy relational or conjunction into a polyhedron."""
        import sympy
        equalities, inequalities = cls._fromsympy(expr)
        return cls(equalities, inequalities)

    def tosympy(self):
        """Return the polyhedron as a sympy conjunction of relationals."""
        import sympy
        constraints = []
        for equality in self.equalities:
            constraints.append(sympy.Eq(equality.tosympy(), 0))
        for inequality in self.inequalities:
            constraints.append(sympy.Ge(inequality.tosympy(), 0))
        return sympy.And(*constraints)

    def _symbolunion(self, *others):
        """
        Return the sorted list of the symbols of self and of all others.
        """
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Convert the polyhedron into an isl.BasicSet.

        symbols is the sequence of symbol names that defines the dimensions
        of the set, in order; it defaults to self.symbols.
        """
        if symbols is None:
            symbols = self.symbols
        dimension = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        # NOTE(review): ls is only ever copied below and never released
        # here -- confirm whether the binding frees it.
        ls = libisl.isl_local_space_from_space(space)
        for alloc, expressions in (
                (libisl.isl_equality_alloc, self.equalities),
                (libisl.isl_inequality_alloc, self.inequalities)):
            for expression in expressions:
                constraint = alloc(libisl.isl_local_space_copy(ls))
                for symbol, coefficient in expression.coefficients():
                    # isl_constraint_set_coefficient_val() takes an isl_val,
                    # not a byte string
                    val = libisl.isl_val_read_from_str(_main_ctx,
                            str(coefficient).encode())
                    dim = symbols.index(symbol)
                    constraint = libisl.isl_constraint_set_coefficient_val(
                            constraint, libisl.isl_dim_set, dim, val)
                if expression.constant != 0:
                    val = libisl.isl_val_read_from_str(_main_ctx,
                            str(expression.constant).encode())
                    constraint = libisl.isl_constraint_set_constant_val(
                            constraint, val)
                # without this call the constraint is never attached and
                # the set stays the universe
                bset = libisl.isl_basic_set_add_constraint(bset, constraint)
        return isl.BasicSet(bset)

    @classmethod
    def _fromisl(cls, bset, symbols):
        """
        Convert an isl basic set back into a polyhedron (not implemented).
        """
        # TODO: compute equalities and inequalities from bset, then
        # return cls(equalities, inequalities).
        #
        # isl example code gives isl form as:
        #     "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
        # our printer is giving form as:
        #     { [i0, i1] : 2i1 >= -2 - i0 }
        raise NotImplementedError
# The empty polyhedron, built directly from the unsatisfiable equality
# 1 == 0 so that Empty is always a Polyhedron instance.
Empty = Polyhedron(equalities=[Constant(1)])

# The universe polyhedron: no constraint at all.
Universe = Polyhedron()
if __name__ == '__main__':
    # Small demonstration: build a polyhedron, convert it to isl and print
    # the result. Use '2a + 2b + 1 == 0' instead for an empty polyhedron.
    polyhedron = Polyhedron('3x + 2y + 3 == 0, y == 0')  # not empty
    basic_set = polyhedron._toisl()
    print(basic_set)
    print(basic_set.constraints())