7544633812ddd419946dec4c8b46e57f50664786
1 import ctypes
, ctypes
.util
5 from fractions
import Fraction
, gcd
8 from .isl
import libisl
13 'constant', 'symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
20 def _polymorphic_method(func
):
21 @functools.wraps(func
)
23 if isinstance(b
, Expression
):
25 if isinstance(b
, numbers
.Rational
):
31 def _polymorphic_operator(func
):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func
)
36 if isinstance(a
, numbers
.Rational
):
39 elif isinstance(a
, Expression
):
41 raise TypeError('arguments must be linear expressions')
45 _main_ctx
= isl
.Context()
50 This class implements linear expressions.
53 def __new__(cls
, coefficients
=None, constant
=0):
54 if isinstance(coefficients
, str):
56 raise TypeError('too many arguments')
57 return cls
.fromstring(coefficients
)
58 self
= super().__new
__(cls
)
59 self
._coefficients
= {}
60 if isinstance(coefficients
, dict):
61 coefficients
= coefficients
.items()
62 if coefficients
is not None:
63 for symbol
, coefficient
in coefficients
:
64 if isinstance(symbol
, Expression
) and symbol
.issymbol():
66 elif not isinstance(symbol
, str):
67 raise TypeError('symbols must be strings')
68 if not isinstance(coefficient
, numbers
.Rational
):
69 raise TypeError('coefficients must be rational numbers')
71 self
._coefficients
[symbol
] = coefficient
72 if not isinstance(constant
, numbers
.Rational
):
73 raise TypeError('constant must be a rational number')
74 self
._constant
= constant
75 self
._symbols
= tuple(sorted(self
._coefficients
))
76 self
._dimension
= len(self
._symbols
)
85 return self
._dimension
87 def coefficient(self
, symbol
):
88 if isinstance(symbol
, Expression
) and symbol
.issymbol():
90 elif not isinstance(symbol
, str):
91 raise TypeError('symbol must be a string')
93 return self
._coefficients
[symbol
]
97 __getitem__
= coefficient
99 def coefficients(self
):
100 for symbol
in self
.symbols
:
101 yield symbol
, self
.coefficient(symbol
)
105 return self
._constant
107 def isconstant(self
):
108 return len(self
._coefficients
) == 0
111 for symbol
in self
.symbols
:
112 yield self
.coefficient(symbol
)
115 def values_int(self
):
116 for symbol
in self
.symbols
:
117 return self
.coefficient(symbol
)
118 return int(self
.constant
)
122 if not self
.issymbol():
123 raise ValueError('not a symbol: {}'.format(self
))
124 for symbol
in self
.symbols
:
128 return len(self
._coefficients
) == 1 and self
._constant
== 0
131 return (not self
.isconstant()) or bool(self
.constant
)
140 def __add__(self
, other
):
141 coefficients
= dict(self
.coefficients())
142 for symbol
, coefficient
in other
.coefficients
:
143 if symbol
in coefficients
:
144 coefficients
[symbol
] += coefficient
146 coefficients
[symbol
] = coefficient
147 constant
= self
.constant
+ other
.constant
148 return Expression(coefficients
, constant
)
153 def __sub__(self
, other
):
154 coefficients
= dict(self
.coefficients())
155 for symbol
, coefficient
in other
.coefficients
:
156 if symbol
in coefficients
:
157 coefficients
[symbol
] -= coefficient
159 coefficients
[symbol
] = -coefficient
160 constant
= self
.constant
- other
.constant
161 return Expression(coefficients
, constant
)
163 def __rsub__(self
, other
):
164 return -(self
- other
)
167 def __mul__(self
, other
):
168 if other
.isconstant():
169 coefficients
= dict(self
.coefficients())
170 for symbol
in coefficients
:
171 coefficients
[symbol
] *= other
.constant
172 constant
= self
.constant
* other
.constant
173 return Expression(coefficients
, constant
)
174 if isinstance(other
, Expression
) and not self
.isconstant():
175 raise ValueError('non-linear expression: '
176 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
177 return NotImplemented
182 def __truediv__(self
, other
):
183 if other
.isconstant():
184 coefficients
= dict(self
.coefficients())
185 for symbol
in coefficients
:
186 coefficients
[symbol
] = \
187 Fraction(coefficients
[symbol
], other
.constant
)
188 constant
= Fraction(self
.constant
, other
.constant
)
189 return Expression(coefficients
, constant
)
190 if isinstance(other
, Expression
):
191 raise ValueError('non-linear expression: '
192 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
193 return NotImplemented
195 def __rtruediv__(self
, other
):
196 if isinstance(other
, self
):
197 if self
.isconstant():
198 constant
= Fraction(other
, self
.constant
)
199 return Expression(constant
=constant
)
201 raise ValueError('non-linear expression: '
202 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
203 return NotImplemented
208 for symbol
in symbols
:
209 coefficient
= self
[symbol
]
214 string
+= ' + {}'.format(symbol
)
215 elif coefficient
== -1:
217 string
+= '-{}'.format(symbol
)
219 string
+= ' - {}'.format(symbol
)
222 string
+= '{}*{}'.format(coefficient
, symbol
)
223 elif coefficient
> 0:
224 string
+= ' + {}*{}'.format(coefficient
, symbol
)
226 assert coefficient
< 0
228 string
+= ' - {}*{}'.format(coefficient
, symbol
)
230 constant
= self
.constant
231 if constant
!= 0 and i
== 0:
232 string
+= '{}'.format(constant
)
234 string
+= ' + {}'.format(constant
)
237 string
+= ' - {}'.format(constant
)
242 def _parenstr(self
, always
=False):
244 if not always
and (self
.isconstant() or self
.issymbol()):
247 return '({})'.format(string
)
250 string
= '{}({{'.format(self
.__class
__.__name
__)
251 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
254 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
255 string
+= '}}, {!r})'.format(self
.constant
)
259 def fromstring(cls
, string
):
260 raise NotImplementedError
263 def __eq__(self
, other
):
265 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
266 return isinstance(other
, Expression
) and \
267 self
._coefficients
== other
._coefficients
and \
268 self
.constant
== other
.constant
271 return hash((self
._coefficients
, self
._constant
))
274 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
275 [value
.denominator
for value
in self
.values()])
279 def _eq(self
, other
):
280 return Polyhedron(equalities
=[(self
- other
)._toint
()])
283 def __le__(self
, other
):
284 return Polyhedron(inequalities
=[(other
- self
)._toint
()])
287 def __lt__(self
, other
):
288 return Polyhedron(inequalities
=[(other
- self
)._toint
() - 1])
291 def __ge__(self
, other
):
292 return Polyhedron(inequalities
=[(self
- other
)._toint
()])
295 def __gt__(self
, other
):
296 return Polyhedron(inequalities
=[(self
- other
)._toint
() - 1])
299 def constant(numerator
=0, denominator
=None):
300 if denominator
is None and isinstance(numerator
, numbers
.Rational
):
301 return Expression(constant
=numerator
)
303 return Expression(constant
=Fraction(numerator
, denominator
))
306 if not isinstance(name
, str):
307 raise TypeError('name must be a string')
308 return Expression(coefficients
={name
: 1})
311 if isinstance(names
, str):
312 names
= names
.replace(',', ' ').split()
313 return (symbol(name
) for name
in names
)
316 @_polymorphic_operator
320 @_polymorphic_operator
324 @_polymorphic_operator
328 @_polymorphic_operator
332 @_polymorphic_operator
339 This class implements polyhedrons.
342 def __new__(cls
, equalities
=None, inequalities
=None):
343 if isinstance(equalities
, str):
344 if inequalities
is not None:
345 raise TypeError('too many arguments')
346 return cls
.fromstring(equalities
)
347 self
= super().__new
__(cls
)
348 self
._equalities
= []
349 if equalities
is not None:
350 for constraint
in equalities
:
351 for value
in constraint
.values():
352 if value
.denominator
!= 1:
353 raise TypeError('non-integer constraint: '
354 '{} == 0'.format(constraint
))
355 self
._equalities
.append(constraint
)
356 self
._equalities
= tuple(self
._equalities
)
357 self
._inequalities
= []
358 if inequalities
is not None:
359 for constraint
in inequalities
:
360 for value
in constraint
.values():
361 if value
.denominator
!= 1:
362 raise TypeError('non-integer constraint: '
363 '{} <= 0'.format(constraint
))
364 self
._inequalities
.append(constraint
)
365 self
._inequalities
= tuple(self
._inequalities
)
366 self
._constraints
= self
._equalities
+ self
._inequalities
367 self
._symbols
= set()
368 for constraint
in self
._constraints
:
369 self
.symbols
.update(constraint
.symbols
)
370 self
._symbols
= tuple(sorted(self
._symbols
))
374 def equalities(self
):
375 return self
._equalities
378 def inequalities(self
):
379 return self
._inequalities
383 return self
._constant
385 def isconstant(self
):
386 return len(self
._coefficients
) == 0
389 return bool(libisl
.isl_basic_set_is_empty(self
._bset
))
392 def constraints(self
):
393 return self
._constraints
401 return len(self
.symbols
)
404 return not self
.is_empty()
406 def __contains__(self
, value
):
407 # is the value in the polyhedron?
408 raise NotImplementedError
410 def __eq__(self
, other
):
411 raise NotImplementedError
416 def isuniverse(self
):
417 return self
== universe
419 def isdisjoint(self
, other
):
420 # return true if the polyhedron has no elements in common with other
421 raise NotImplementedError
423 def issubset(self
, other
):
424 raise NotImplementedError
426 def __le__(self
, other
):
427 return self
.issubset(other
)
429 def __lt__(self
, other
):
430 raise NotImplementedError
432 def issuperset(self
, other
):
433 # test whether every element in other is in the polyhedron
434 raise NotImplementedError
436 def __ge__(self
, other
):
437 return self
.issuperset(other
)
439 def __gt__(self
, other
):
440 raise NotImplementedError
442 def union(self
, *others
):
443 # return a new polyhedron with elements from the polyhedron and all
444 # others (convex union)
445 raise NotImplementedError
447 def __or__(self
, other
):
448 return self
.union(other
)
450 def intersection(self
, *others
):
451 # return a new polyhedron with elements common to the polyhedron and all
453 # a poor man's implementation could be:
454 # equalities = list(self.equalities)
455 # inequalities = list(self.inequalities)
456 # for other in others:
457 # equalities.extend(other.equalities)
458 # inequalities.extend(other.inequalities)
459 # return self.__class__(equalities, inequalities)
460 raise NotImplementedError
462 def __and__(self
, other
):
463 return self
.intersection(other
)
465 def difference(self
, *others
):
466 # return a new polyhedron with elements in the polyhedron that are not
468 raise NotImplementedError
470 def __sub__(self
, other
):
471 return self
.difference(other
)
475 for constraint
in self
.equalities
:
476 constraints
.append('{} == 0'.format(constraint
))
477 for constraint
in self
.inequalities
:
478 constraints
.append('{} >= 0'.format(constraint
))
479 return '{{{}}}'.format(', '.join(constraints
))
482 equalities
= list(self
.equalities
)
483 inequalities
= list(self
.inequalities
)
484 return '{}(equalities={!r}, inequalities={!r})' \
485 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
488 def fromstring(cls
, string
):
489 raise NotImplementedError
491 def _symbolunion(self
, *others
):
492 symbols
= set(self
.symbols
)
494 symbols
.update(other
.symbols
)
495 return sorted(symbols
)
497 def _to_isl(self
, symbols
=None):
499 symbols
= self
.symbols
500 num_coefficients
= len(symbols
)
501 space
= libisl
.isl_space_set_alloc(_main_ctx
, 0, num_coefficients
)
502 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
503 ls
= libisl
.isl_local_space_from_space(space
)
504 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
505 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
506 '''if there are equalities/inequalities, take each constant and coefficient and add as a constraint to the basic set'''
507 if list(self
.equalities
): #check if any equalities exist
508 for eq
in self
.equalities
:
509 coeff_eq
= dict(eq
.coefficients())
512 ceq
= libisl
.isl_constraint_set_constant_si(ceq
, value
)
514 num
= coeff_eq
.get(eq
)
515 iden
= symbols
.index(eq
)
516 ceq
= libisl
.isl_constraint_set_coefficient_si(ceq
, libisl
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
517 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
518 if list(self
.inequalities
): #check if any inequalities exist
519 for ineq
in self
.inequalities
:
520 coeff_in
= dict(ineq
.coefficients())
522 value
= ineq
.constant
523 cin
= libisl
.isl_constraint_set_constant_si(cin
, value
)
524 for ineq
in coeff_in
:
525 num
= coeff_in
.get(ineq
)
526 iden
= symbols
.index(ineq
)
527 cin
= libisl
.isl_constraint_set_coefficient_si(cin
, libisl
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
528 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
529 bset
= isl
.BasicSet(bset
)
533 def from_isl(cls
, bset
):
534 '''takes basic set in isl form and puts back into python version of polyhedron
535 isl example code gives isl form as:
536 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
537 our printer is giving form as:
538 b'{ [i0] : 1 = 0 }' '''
539 raise NotImplementedError
542 return cls(equalities
, inequalities
)
544 # if self._equalities:
545 # constraints = libisl.isl_basic_set_equalities_matrix(bset, 3)
546 # elif self._inequalities:
547 # constraints = libisl.isl_basic_set_inequalities_matrix(bset, 3)
551 empty
= None #eq(0,1)
552 universe
= None #Polyhedron()
555 if __name__
== '__main__':
556 ex1
= Expression(coefficients
={'a': 1, 'x': 2}, constant
=2)
557 ex2
= Expression(coefficients
={'a': 3 , 'b': 2}, constant
=3)
558 p
= Polyhedron(inequalities
=[ex1
, ex2
])