6 from fractions
import Fraction
, gcd
9 from pypol
.isl
import libisl
13 'Expression', 'Constant', 'Symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
20 def _polymorphic_method(func
):
21 @functools.wraps(func
)
23 if isinstance(b
, Expression
):
25 if isinstance(b
, numbers
.Rational
):
31 def _polymorphic_operator(func
):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func
)
36 if isinstance(a
, numbers
.Rational
):
39 elif isinstance(a
, Expression
):
41 raise TypeError('arguments must be linear expressions')
45 _main_ctx
= isl
.Context()
50 This class implements linear expressions.
53 def __new__(cls
, coefficients
=None, constant
=0):
54 if isinstance(coefficients
, str):
56 raise TypeError('too many arguments')
57 return cls
.fromstring(coefficients
)
58 if isinstance(coefficients
, dict):
59 coefficients
= coefficients
.items()
60 if coefficients
is None:
61 return Constant(constant
)
62 coefficients
= [(symbol
, coefficient
)
63 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
64 if len(coefficients
) == 0:
65 return Constant(constant
)
66 elif len(coefficients
) == 1 and constant
== 0:
67 symbol
, coefficient
= coefficients
[0]
70 self
= object().__new
__(cls
)
71 self
._coefficients
= {}
72 for symbol
, coefficient
in coefficients
:
73 if isinstance(symbol
, Symbol
):
75 elif not isinstance(symbol
, str):
76 raise TypeError('symbols must be strings or Symbol instances')
77 if isinstance(coefficient
, Constant
):
78 coefficient
= coefficient
.constant
79 if not isinstance(coefficient
, numbers
.Rational
):
80 raise TypeError('coefficients must be rational numbers or Constant instances')
81 self
._coefficients
[symbol
] = coefficient
82 if isinstance(constant
, Constant
):
83 constant
= constant
.constant
84 if not isinstance(constant
, numbers
.Rational
):
85 raise TypeError('constant must be a rational number or a Constant instance')
86 self
._constant
= constant
87 self
._symbols
= tuple(sorted(self
._coefficients
))
88 self
._dimension
= len(self
._symbols
)
92 def _fromast(cls
, node
):
93 if isinstance(node
, ast
.Module
):
94 assert len(node
.body
) == 1
95 return cls
._fromast
(node
.body
[0])
96 elif isinstance(node
, ast
.Expr
):
97 return cls
._fromast
(node
.value
)
98 elif isinstance(node
, ast
.Name
):
99 return Symbol(node
.id)
100 elif isinstance(node
, ast
.Num
):
101 return Constant(node
.n
)
102 elif isinstance(node
, ast
.UnaryOp
):
103 if isinstance(node
.op
, ast
.USub
):
104 return -cls
._fromast
(node
.operand
)
105 elif isinstance(node
, ast
.BinOp
):
106 left
= cls
._fromast
(node
.left
)
107 right
= cls
._fromast
(node
.right
)
108 if isinstance(node
.op
, ast
.Add
):
110 elif isinstance(node
.op
, ast
.Sub
):
112 elif isinstance(node
.op
, ast
.Mult
):
114 elif isinstance(node
.op
, ast
.Div
):
116 raise SyntaxError('invalid syntax')
def fromstring(cls, string):
    """Build an expression from its textual form.

    Implicit products such as ``2x`` or ``3(x + 1)`` are first rewritten
    with an explicit ``*`` so the text becomes valid Python; the parsed
    tree is then converted by ``cls._fromast``.
    """
    # A number or a closing parenthesis immediately followed by an
    # identifier or an opening parenthesis denotes a multiplication.
    implicit_product = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')
    explicit = implicit_product.sub(r'\1*\2', string)
    # NOTE: 'eval' is passed positionally, so it is the *filename* argument
    # of ast.parse, not the mode; the result is therefore an ast.Module.
    return cls._fromast(ast.parse(explicit, 'eval'))
130 return self
._dimension
132 def coefficient(self
, symbol
):
133 if isinstance(symbol
, Symbol
):
135 elif not isinstance(symbol
, str):
136 raise TypeError('symbol must be a string or a Symbol instance')
138 return self
._coefficients
[symbol
]
142 __getitem__
= coefficient
def coefficients(self):
    """Iterate over ``(symbol, coefficient)`` pairs, following ``self.symbols``."""
    yield from ((name, self.coefficient(name)) for name in self.symbols)
150 return self
._constant
152 def isconstant(self
):
156 for symbol
in self
.symbols
:
157 yield self
.coefficient(symbol
)
173 def __add__(self
, other
):
174 coefficients
= dict(self
.coefficients())
175 for symbol
, coefficient
in other
.coefficients():
176 if symbol
in coefficients
:
177 coefficients
[symbol
] += coefficient
179 coefficients
[symbol
] = coefficient
180 constant
= self
.constant
+ other
.constant
181 return Expression(coefficients
, constant
)
186 def __sub__(self
, other
):
187 coefficients
= dict(self
.coefficients())
188 for symbol
, coefficient
in other
.coefficients():
189 if symbol
in coefficients
:
190 coefficients
[symbol
] -= coefficient
192 coefficients
[symbol
] = -coefficient
193 constant
= self
.constant
- other
.constant
194 return Expression(coefficients
, constant
)
196 def __rsub__(self
, other
):
197 return -(self
- other
)
def __mul__(self, other):
    """Multiply this expression by a constant expression.

    Returns a new Expression whose coefficients and constant term are
    scaled by ``other.constant``.  Raises ValueError when both operands
    are non-constant expressions (the product would not be linear) and
    returns NotImplemented otherwise.
    """
    if not other.isconstant():
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented
    scaled = {symbol: coefficient * other.constant
              for symbol, coefficient in self.coefficients()}
    return Expression(scaled, self.constant * other.constant)
def __truediv__(self, other):
    """Divide this expression by a constant expression.

    Every coefficient and the constant term are turned into exact
    ``Fraction`` quotients by ``other.constant``.  Raises ValueError when
    *other* is a non-constant Expression (the quotient would not be
    linear) and returns NotImplemented for anything else.
    """
    if not other.isconstant():
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented
    quotients = {symbol: Fraction(coefficient, other.constant)
                 for symbol, coefficient in self.coefficients()}
    return Expression(quotients, Fraction(self.constant, other.constant))
def __rtruediv__(self, other):
    """Reflected true division: compute ``other / self``.

    The quotient is only linear when ``self`` is constant.

    * *other* is a rational number: returns the constant expression
      ``other / self.constant``.
    * *other* is an Expression: delegates to ``other.__truediv__(self)``.
    * anything else: returns NotImplemented.

    Raises ValueError when ``self`` is not constant.

    The previous implementation tested ``isinstance(other, self)``, which
    passes an instance where a class is required and therefore raised
    TypeError on every call.
    """
    if isinstance(other, numbers.Rational):
        if self.isconstant():
            return Expression(constant=Fraction(other, self.constant))
        # Plain numbers have no _parenstr(); format them directly.
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other, self._parenstr()))
    if isinstance(other, Expression):
        if self.isconstant():
            # Explicit method call (not ``other / self``) so that a
            # NotImplemented result cannot bounce back into this method.
            return other.__truediv__(self)
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other._parenstr(), self._parenstr()))
    return NotImplemented
241 for symbol
in self
.symbols
:
242 coefficient
= self
.coefficient(symbol
)
247 string
+= ' + {}'.format(symbol
)
248 elif coefficient
== -1:
250 string
+= '-{}'.format(symbol
)
252 string
+= ' - {}'.format(symbol
)
255 string
+= '{}*{}'.format(coefficient
, symbol
)
256 elif coefficient
> 0:
257 string
+= ' + {}*{}'.format(coefficient
, symbol
)
259 assert coefficient
< 0
261 string
+= ' - {}*{}'.format(coefficient
, symbol
)
263 constant
= self
.constant
264 if constant
!= 0 and i
== 0:
265 string
+= '{}'.format(constant
)
267 string
+= ' + {}'.format(constant
)
270 string
+= ' - {}'.format(constant
)
275 def _parenstr(self
, always
=False):
277 if not always
and (self
.isconstant() or self
.issymbol()):
280 return '({})'.format(string
)
283 string
= '{}({{'.format(self
.__class
__.__name
__)
284 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
287 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
288 string
+= '}}, {!r})'.format(self
.constant
)
292 def __eq__(self
, other
):
294 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
295 return isinstance(other
, Expression
) and \
296 self
._coefficients
== other
._coefficients
and \
297 self
.constant
== other
.constant
300 return hash((tuple(sorted(self
._coefficients
.items())), self
._constant
))
303 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
304 [value
.denominator
for value
in self
.values()])
308 def _eq(self
, other
):
309 return Polyhedron(equalities
=[(self
- other
)._toint
()])
312 def __le__(self
, other
):
313 return Polyhedron(inequalities
=[(other
- self
)._toint
()])
316 def __lt__(self
, other
):
317 return Polyhedron(inequalities
=[(other
- self
)._toint
() - 1])
320 def __ge__(self
, other
):
321 return Polyhedron(inequalities
=[(self
- other
)._toint
()])
324 def __gt__(self
, other
):
325 return Polyhedron(inequalities
=[(self
- other
)._toint
() - 1])
328 class Constant(Expression
):
330 def __new__(cls
, numerator
=0, denominator
=None):
331 self
= object().__new
__(cls
)
332 if denominator
is None:
333 if isinstance(numerator
, numbers
.Rational
):
334 self
._constant
= numerator
335 elif isinstance(numerator
, Constant
):
336 self
._constant
= numerator
.constant
338 raise TypeError('constant must be a rational number or a Constant instance')
340 self
._constant
= Fraction(numerator
, denominator
)
341 self
._coefficients
= {}
346 def isconstant(self
):
350 return bool(self
.constant
)
353 return '{}({!r})'.format(self
.__class
__.__name
__, self
._constant
)
356 class Symbol(Expression
):
358 def __new__(cls
, name
):
359 if isinstance(name
, Symbol
):
361 elif not isinstance(name
, str):
362 raise TypeError('name must be a string or a Symbol instance')
363 self
= object().__new
__(cls
)
364 self
._coefficients
= {name
: 1}
366 self
._symbols
= tuple(name
)
379 return '{}({!r})'.format(self
.__class
__.__name
__, self
._name
)
382 if isinstance(names
, str):
383 names
= names
.replace(',', ' ').split()
384 return (Symbol(name
) for name
in names
)
387 @_polymorphic_operator
391 @_polymorphic_operator
395 @_polymorphic_operator
399 @_polymorphic_operator
403 @_polymorphic_operator
410 This class implements polyhedrons.
413 def __new__(cls
, equalities
=None, inequalities
=None):
414 if isinstance(equalities
, str):
415 if inequalities
is not None:
416 raise TypeError('too many arguments')
417 return cls
.fromstring(equalities
)
418 self
= super().__new
__(cls
)
419 self
._equalities
= []
420 if equalities
is not None:
421 for constraint
in equalities
:
422 for value
in constraint
.values():
423 if value
.denominator
!= 1:
424 raise TypeError('non-integer constraint: '
425 '{} == 0'.format(constraint
))
426 self
._equalities
.append(constraint
)
427 self
._equalities
= tuple(self
._equalities
)
428 self
._inequalities
= []
429 if inequalities
is not None:
430 for constraint
in inequalities
:
431 for value
in constraint
.values():
432 if value
.denominator
!= 1:
433 raise TypeError('non-integer constraint: '
434 '{} <= 0'.format(constraint
))
435 self
._inequalities
.append(constraint
)
436 self
._inequalities
= tuple(self
._inequalities
)
437 self
._constraints
= self
._equalities
+ self
._inequalities
438 self
._symbols
= set()
439 for constraint
in self
._constraints
:
440 self
.symbols
.update(constraint
.symbols
)
441 self
._symbols
= tuple(sorted(self
._symbols
))
445 def fromstring(cls
, string
):
446 string
= string
.strip()
447 string
= re
.sub(r
'^\{\s*|\s*\}$', '', string
)
448 string
= re
.sub(r
'([^<=>])=([^<=>])', r
'\1==\2', string
)
449 string
= re
.sub(r
'(\d+|\))\s*([^\W\d_]\w*|\()', r
'\1*\2', string
)
452 for cstr
in re
.split(r
',|;|and|&&|/\\|∧', string
, flags
=re
.I
):
453 tree
= ast
.parse(cstr
.strip(), 'eval')
454 if not isinstance(tree
, ast
.Module
) or len(tree
.body
) != 1:
455 raise SyntaxError('invalid syntax')
457 if not isinstance(node
, ast
.Expr
):
458 raise SyntaxError('invalid syntax')
460 if not isinstance(node
, ast
.Compare
):
461 raise SyntaxError('invalid syntax')
462 left
= Expression
._fromast
(node
.left
)
463 for i
in range(len(node
.ops
)):
465 right
= Expression
._fromast
(node
.comparators
[i
])
466 if isinstance(op
, ast
.Lt
):
467 inequalities
.append(right
- left
- 1)
468 elif isinstance(op
, ast
.LtE
):
469 inequalities
.append(right
- left
)
470 elif isinstance(op
, ast
.Eq
):
471 equalities
.append(left
- right
)
472 elif isinstance(op
, ast
.GtE
):
473 inequalities
.append(left
- right
)
474 elif isinstance(op
, ast
.Gt
):
475 inequalities
.append(left
- right
- 1)
477 raise SyntaxError('invalid syntax')
479 return cls(equalities
, inequalities
)
482 def equalities(self
):
483 return self
._equalities
486 def inequalities(self
):
487 return self
._inequalities
490 def constraints(self
):
491 return self
._constraints
499 return len(self
.symbols
)
502 return not self
.is_empty()
504 def __contains__(self
, value
):
505 # is the value in the polyhedron?
506 raise NotImplementedError
508 def __eq__(self
, other
):
509 # works correctly when symbols is not passed
510 # should be equal if values are the same even if symbols are different
512 other
= other
._toisl
()
513 return bool(libisl
.isl_basic_set_plain_is_equal(bset
, other
))
517 return bool(libisl
.isl_basic_set_is_empty(bset
))
519 def isuniverse(self
):
521 return bool(libisl
.isl_basic_set_is_universe(bset
))
523 def isdisjoint(self
, other
):
524 # return true if the polyhedron has no elements in common with other
525 #symbols = self._symbolunion(other)
527 other
= other
._toisl
()
528 return bool(libisl
.isl_set_is_disjoint(bset
, other
))
def issubset(self, other):
    """Return True if every point of this polyhedron also lies in *other*.

    Both polyhedra are converted to ISL sets over the union of their
    symbols so that they live in the same space before being compared.

    The previous code called ``isl_set_is_strict_subset(other, bset)``,
    i.e. it tested whether *other* is a proper subset of ``self``: the
    operands were reversed and equal sets were rejected.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # NOTE(review): assumes the libisl binding exposes isl_set_is_subset,
    # like the isl_set_is_strict_subset call used elsewhere in this class.
    return bool(libisl.isl_set_is_subset(bset, other))
537 def __le__(self
, other
):
538 return self
.issubset(other
)
def __lt__(self, other):
    """``self < other``: True if this polyhedron is a proper subset of *other*.

    Both operands are converted to ISL sets over the union of their
    symbols.  The first argument of ``isl_set_is_strict_subset`` is the
    candidate subset, so ``self`` must come first; the previous code passed
    ``(other, bset)`` and thus answered ``self > other`` instead.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    return bool(libisl.isl_set_is_strict_subset(bset, other))
def issuperset(self, other):
    """Return True if every point of *other* lies in this polyhedron.

    Implemented as the mirror image of the subset test: both polyhedra are
    converted to ISL sets over the union of their symbols and *other* is
    checked for inclusion in ``self``.  Previously this was a stub that
    always raised NotImplementedError, which also made ``>=`` unusable.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # NOTE(review): assumes the libisl binding exposes isl_set_is_subset.
    return bool(libisl.isl_set_is_subset(other, bset))
550 def __ge__(self
, other
):
551 return self
.issuperset(other
)
def __gt__(self, other):
    """``self > other``: True if *other* is a proper subset of this polyhedron.

    Both operands are converted to ISL sets over the union of their
    symbols.  The previous code computed this very result, discarded it
    and then unconditionally raised NotImplementedError.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # The candidate subset goes first: other ⊊ self.
    return bool(libisl.isl_set_is_strict_subset(other, bset))
560 def union(self
, *others
):
561 # return a new polyhedron with elements from the polyhedron and all
562 # others (convex union)
563 raise NotImplementedError
565 def __or__(self
, other
):
566 return self
.union(other
)
568 def intersection(self
, *others
):
569 # return a new polyhedron with elements common to the polyhedron and all
571 # a poor man's implementation could be:
572 # equalities = list(self.equalities)
573 # inequalities = list(self.inequalities)
574 # for other in others:
575 # equalities.extend(other.equalities)
576 # inequalities.extend(other.inequalities)
577 # return self.__class__(equalities, inequalities)
578 raise NotImplementedError
580 def __and__(self
, other
):
581 return self
.intersection(other
)
583 def difference(self
, other
):
584 # return a new polyhedron with elements in the polyhedron that are not in the other
585 symbols
= self
._symbolunion
(other
)
586 bset
= self
._toisl
(symbols
)
587 other
= other
._toisl
(symbols
)
588 difference
= libisl
.isl_set_subtract(bset
, other
)
591 def __sub__(self
, other
):
592 return self
.difference(other
)
596 for constraint
in self
.equalities
:
597 constraints
.append('{} == 0'.format(constraint
))
598 for constraint
in self
.inequalities
:
599 constraints
.append('{} >= 0'.format(constraint
))
600 return '{{{}}}'.format(', '.join(constraints
))
605 elif self
.isuniverse():
608 equalities
= list(self
.equalities
)
609 inequalities
= list(self
.inequalities
)
610 return '{}(equalities={!r}, inequalities={!r})' \
611 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
613 def _symbolunion(self
, *others
):
614 symbols
= set(self
.symbols
)
616 symbols
.update(other
.symbols
)
617 return sorted(symbols
)
619 def _toisl(self
, symbols
=None):
621 symbols
= self
.symbols
622 dimension
= len(symbols
)
623 space
= libisl
.isl_space_set_alloc(_main_ctx
, 0, dimension
)
624 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
625 ls
= libisl
.isl_local_space_from_space(space
)
626 for equality
in self
.equalities
:
627 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
628 for symbol
, coefficient
in equality
.coefficients():
629 val
= str(coefficient
).encode()
630 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
631 dim
= symbols
.index(symbol
)
632 ceq
= libisl
.isl_constraint_set_coefficient_val(ceq
, libisl
.isl_dim_set
, dim
, val
)
633 if equality
.constant
!= 0:
634 val
= str(equality
.constant
).encode()
635 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
636 ceq
= libisl
.isl_constraint_set_constant_val(ceq
, val
)
637 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
638 for inequality
in self
.inequalities
:
639 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
640 for symbol
, coefficient
in inequality
.coefficients():
641 val
= str(coefficient
).encode()
642 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
643 dim
= symbols
.index(symbol
)
644 cin
= libisl
.isl_constraint_set_coefficient_val(cin
, libisl
.isl_dim_set
, dim
, val
)
645 if inequality
.constant
!= 0:
646 val
= str(inequality
.constant
).encode()
647 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
648 cin
= libisl
.isl_constraint_set_constant_val(cin
, val
)
649 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
650 bset
= isl
.BasicSet(bset
)
654 def _fromisl(cls
, bset
, symbols
):
655 raise NotImplementedError
658 return cls(equalities
, inequalities
)
659 '''takes basic set in isl form and puts back into python version of polyhedron
660 isl example code gives isl form as:
661 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
662 our printer is giving form as:
663 { [i0, i1] : 2i1 >= -2 - i0 } '''
666 Universe
= Polyhedron()
668 if __name__
== '__main__':
669 p1
= Polyhedron('2a + 2b + 1 == 0') # empty
671 p2
= Polyhedron('3x + 2y + 3 == 0') # not empty