07c62936ff31d00b07cf675f5e98c810b1d8107f
1 import ctypes
, ctypes
.util
5 from fractions
import Fraction
, gcd
7 from . import isl
, islhelper
8 from .isl
import libisl
, Context
, BasicSet
13 'constant', 'symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
20 def _polymorphic_method(func
):
21 @functools.wraps(func
)
23 if isinstance(b
, Expression
):
25 if isinstance(b
, numbers
.Rational
):
31 def _polymorphic_operator(func
):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func
)
36 if isinstance(a
, numbers
.Rational
):
39 elif isinstance(a
, Expression
):
41 raise TypeError('arguments must be linear expressions')
50 This class implements linear expressions.
53 def __new__(cls
, coefficients
=None, constant
=0):
54 if isinstance(coefficients
, str):
56 raise TypeError('too many arguments')
57 return cls
.fromstring(coefficients
)
58 self
= super().__new
__(cls
)
59 self
._coefficients
= {}
60 if isinstance(coefficients
, dict):
61 coefficients
= coefficients
.items()
62 if coefficients
is not None:
63 for symbol
, coefficient
in coefficients
:
64 if isinstance(symbol
, Expression
) and symbol
.issymbol():
66 elif not isinstance(symbol
, str):
67 raise TypeError('symbols must be strings')
68 if not isinstance(coefficient
, numbers
.Rational
):
69 raise TypeError('coefficients must be rational numbers')
71 self
._coefficients
[symbol
] = coefficient
72 if not isinstance(constant
, numbers
.Rational
):
73 raise TypeError('constant must be a rational number')
74 self
._constant
= constant
79 yield from sorted(self
._coefficients
)
83 return len(list(self
.symbols()))
85 def coefficient(self
, symbol
):
86 if isinstance(symbol
, Expression
) and symbol
.issymbol():
88 elif not isinstance(symbol
, str):
89 raise TypeError('symbol must be a string')
91 return self
._coefficients
[symbol
]
95 __getitem__
= coefficient
def coefficients(self):
    """
    Iterate over (symbol, coefficient) pairs, in the order in which
    symbols() produces the symbols.
    """
    yield from ((name, self.coefficient(name)) for name in self.symbols())
104 return self
._constant
def isconstant(self):
    """
    Return True when no symbol coefficient is stored, i.e. the expression
    consists of its constant term only.
    """
    return not self._coefficients
110 for symbol
in self
.symbols():
111 yield self
.coefficient(symbol
)
def values_int(self):
    """
    Return the coefficient of the first symbol produced by symbols(); when
    there is no symbol at all, return the constant term converted to int.

    NOTE(review): only the first coefficient is ever returned, and it is not
    converted to int. A generator (`yield`) may have been intended; confirm
    against callers before changing the return shape.
    """
    names = iter(self.symbols())
    try:
        first = next(names)
    except StopIteration:
        return int(self.constant)
    return self.coefficient(first)
120 if not self
.issymbol():
121 raise ValueError('not a symbol: {}'.format(self
))
122 for symbol
in self
.symbols():
126 return len(self
._coefficients
) == 1 and self
._constant
== 0
129 return (not self
.isconstant()) or bool(self
.constant
)
def __add__(self, other):
    """
    Return a new Expression that is the sum of self and other.

    Coefficients of symbols present in both operands are added together;
    symbols present in only one operand keep their coefficient. The constant
    terms are added as well.
    """
    merged = dict(self.coefficients)
    for name, value in other.coefficients:
        merged[name] = merged[name] + value if name in merged else value
    return Expression(merged, self.constant + other.constant)
def __sub__(self, other):
    """
    Return a new Expression that is the difference self - other.

    Coefficients of symbols present in both operands are subtracted; a symbol
    that only appears in other contributes its negated coefficient. The
    constant terms are subtracted as well.
    """
    merged = dict(self.coefficients)
    for name, value in other.coefficients:
        merged[name] = merged[name] - value if name in merged else -value
    return Expression(merged, self.constant - other.constant)
161 def __rsub__(self
, other
):
162 return -(self
- other
)
def __mul__(self, other):
    """
    Multiply the expression by a constant expression.

    When other is constant, every coefficient and the constant term are
    scaled by other.constant and the result is returned as a new Expression.
    Otherwise a ValueError is raised if other is an Expression and self is
    not constant either (the product would not be linear); in the remaining
    cases NotImplemented is returned.
    """
    if not other.isconstant():
        if isinstance(other, Expression) and not self.isconstant():
            message = 'non-linear expression: {} * {}'.format(
                self._parenstr(), other._parenstr())
            raise ValueError(message)
        return NotImplemented
    scaled = {
        name: value * other.constant
        for name, value in dict(self.coefficients).items()
    }
    return Expression(scaled, self.constant * other.constant)
def __truediv__(self, other):
    """
    Divide the expression by a constant expression.

    When other is constant, every coefficient and the constant term are
    divided by other.constant (as exact Fractions) and the result is returned
    as a new Expression.

    Raises:
        ValueError: other is a non-constant Expression (the quotient would
            not be linear).
        ZeroDivisionError: raised by Fraction when other.constant is zero.

    Returns NotImplemented when other is neither constant nor an Expression.
    """
    if other.isconstant():
        divisor = other.constant
        # The sibling operators (__add__, __sub__, __mul__) read
        # self.coefficients as a plain attribute, so calling it here as
        # self.coefficients() was inconsistent with them. Going through
        # symbols()/coefficient(), which are called as methods everywhere,
        # yields the same pairs without depending on that.
        coefficients = {
            symbol: Fraction(self.coefficient(symbol), divisor)
            for symbol in self.symbols()
        }
        constant = Fraction(self.constant, divisor)
        return Expression(coefficients, constant)
    if isinstance(other, Expression):
        raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
    return NotImplemented
193 def __rtruediv__(self
, other
):
194 if isinstance(other
, self
):
195 if self
.isconstant():
196 constant
= Fraction(other
, self
.constant
)
197 return Expression(constant
=constant
)
199 raise ValueError('non-linear expression: '
200 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
201 return NotImplemented
205 symbols
= sorted(self
.symbols())
207 for symbol
in symbols
:
208 coefficient
= self
[symbol
]
213 string
+= ' + {}'.format(symbol
)
214 elif coefficient
== -1:
216 string
+= '-{}'.format(symbol
)
218 string
+= ' - {}'.format(symbol
)
221 string
+= '{}*{}'.format(coefficient
, symbol
)
222 elif coefficient
> 0:
223 string
+= ' + {}*{}'.format(coefficient
, symbol
)
225 assert coefficient
< 0
227 string
+= ' - {}*{}'.format(coefficient
, symbol
)
229 constant
= self
.constant
230 if constant
!= 0 and i
== 0:
231 string
+= '{}'.format(constant
)
233 string
+= ' + {}'.format(constant
)
236 string
+= ' - {}'.format(constant
)
241 def _parenstr(self
, always
=False):
243 if not always
and (self
.isconstant() or self
.issymbol()):
246 return '({})'.format(string
)
249 string
= '{}({{'.format(self
.__class
__.__name
__)
250 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients
):
253 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
254 string
+= '}}, {!r})'.format(self
.constant
)
def fromstring(cls, string):
    """
    Alternate constructor taking the textual form of an expression.

    Not implemented yet: always raises NotImplementedError.
    """
    raise NotImplementedError()
def __eq__(self, other):
    """
    Structural equality: True only when other is an Expression with the same
    coefficients and the same constant term.
    """
    # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
    if not isinstance(other, Expression):
        return False
    if not (self._coefficients == other._coefficients):
        return False
    return self.constant == other.constant
270 return hash((self
._coefficients
, self
._constant
))
273 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
274 [value
.denominator
for value
in self
.values()])
def _eq(self, other):
    """
    Build the Polyhedron whose only equality constraint is the canonified
    difference self - other.
    """
    difference = (self - other)._canonify()
    return Polyhedron(equalities=[difference])
def __le__(self, other):
    """
    Build the Polyhedron for self <= other: its only inequality constraint is
    the canonified difference other - self.
    """
    slack = (other - self)._canonify()
    return Polyhedron(inequalities=[slack])
def __lt__(self, other):
    """
    Build the Polyhedron for self < other: its only inequality constraint is
    the canonified difference other - self, minus 1.
    """
    slack = (other - self)._canonify()
    return Polyhedron(inequalities=[slack - 1])
def __ge__(self, other):
    """
    Build the Polyhedron for self >= other: its only inequality constraint is
    the canonified difference self - other.
    """
    slack = (self - other)._canonify()
    return Polyhedron(inequalities=[slack])
def __gt__(self, other):
    """
    Build the Polyhedron for self > other: its only inequality constraint is
    the canonified difference self - other, minus 1.
    """
    slack = (self - other)._canonify()
    return Polyhedron(inequalities=[slack - 1])
def constant(numerator=0, denominator=None):
    """
    Create a constant Expression.

    A rational numerator given without a denominator is used as is; in every
    other case the value is built with Fraction(numerator, denominator).
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        return Expression(constant=Fraction(numerator, denominator))
    return Expression(constant=numerator)
305 if not isinstance(name
, str):
306 raise TypeError('name must be a string')
307 return Expression(coefficients
={name
: 1})
310 if isinstance(names
, str):
311 names
= names
.replace(',', ' ').split()
312 return (symbol(name
) for name
in names
)
315 @_polymorphic_operator
319 @_polymorphic_operator
323 @_polymorphic_operator
327 @_polymorphic_operator
331 @_polymorphic_operator
338 This class implements polyhedrons.
341 def __new__(cls
, equalities
=None, inequalities
=None):
342 if isinstance(equalities
, str):
343 if inequalities
is not None:
344 raise TypeError('too many arguments')
345 return cls
.fromstring(equalities
)
346 self
= super().__new
__(cls
)
347 self
._equalities
= []
348 if equalities
is not None:
349 for constraint
in equalities
:
350 for value
in constraint
.values():
351 if value
.denominator
!= 1:
352 raise TypeError('non-integer constraint: '
353 '{} == 0'.format(constraint
))
354 self
._equalities
.append(constraint
)
355 self
._inequalities
= []
356 if inequalities
is not None:
357 for constraint
in inequalities
:
358 for value
in constraint
.values():
359 if value
.denominator
!= 1:
360 raise TypeError('non-integer constraint: '
361 '{} <= 0'.format(constraint
))
362 self
._inequalities
.append(constraint
)
def equalities(self):
    """
    Iterate over the equality constraints stored in the polyhedron.
    """
    for constraint in self._equalities:
        yield constraint
def inequalities(self):
    """
    Iterate over the inequality constraints stored in the polyhedron.
    """
    for constraint in self._inequalities:
        yield constraint
375 return self
._constant
def isconstant(self):
    """
    Return True when self._coefficients is empty.

    NOTE(review): the constructor code visible for this class only sets
    _equalities and _inequalities, so _coefficients may never exist on a
    polyhedron -- confirm whether this method is a leftover.
    """
    return not len(self._coefficients)
381 return bool(libisl
.isl_basic_set_is_empty(self
._bset
))
def constraints(self):
    """
    Iterate over every constraint of the polyhedron: the equalities first,
    then the inequalities.
    """
    for constraint in self.equalities:
        yield constraint
    for constraint in self.inequalities:
        yield constraint
389 for constraint
in self
.constraints():
390 s
.update(constraint
.symbols())
395 return len(self
.symbols())
398 # return false if the polyhedron is empty, true otherwise
399 if self
._equalities
or self
._inequalities
:
404 def __contains__(self
, value
):
405 # is the value in the polyhedron?
406 raise NotImplementedError
408 def __eq__(self
, other
):
409 raise NotImplementedError
def isuniverse(self):
    """
    Return the result of comparing the polyhedron with the module-level
    `universe` object using ==.
    """
    same = self == universe
    return same
def isdisjoint(self, other):
    """
    Return True if the polyhedron has no elements in common with other.

    Not implemented yet: always raises NotImplementedError.
    """
    raise NotImplementedError()
def issubset(self, other):
    """
    Subset test against other.

    Not implemented yet: always raises NotImplementedError.
    """
    raise NotImplementedError()
424 def __le__(self
, other
):
425 return self
.issubset(other
)
427 def __lt__(self
, other
):
428 raise NotImplementedError
430 def issuperset(self
, other
):
431 # test whether every element in other is in the polyhedron
433 if value
== self
.constraints():
437 raise NotImplementedError
439 def __ge__(self
, other
):
440 return self
.issuperset(other
)
442 def __gt__(self
, other
):
443 raise NotImplementedError
def union(self, *others):
    """
    Return a new polyhedron with elements from the polyhedron and all others
    (convex union).

    Not implemented yet: always raises NotImplementedError.
    """
    raise NotImplementedError()
450 def __or__(self
, other
):
451 return self
.union(other
)
def intersection(self, *others):
    """
    Return a new polyhedron with the elements common to the polyhedron and
    all others.

    Not implemented yet: always raises NotImplementedError.
    """
    # A poor man's implementation could be:
    #   equalities = list(self.equalities)
    #   inequalities = list(self.inequalities)
    #   for other in others:
    #       equalities.extend(other.equalities)
    #       inequalities.extend(other.inequalities)
    #   return self.__class__(equalities, inequalities)
    raise NotImplementedError()
465 def __and__(self
, other
):
466 return self
.intersection(other
)
def difference(self, *others):
    """
    Return a new polyhedron with the elements of the polyhedron that are not
    in the others.

    Not implemented yet: always raises NotImplementedError.
    """
    raise NotImplementedError()
473 def __sub__(self
, other
):
474 return self
.difference(other
)
478 for constraint
in self
.equalities
:
479 constraints
.append('{} == 0'.format(constraint
))
480 for constraint
in self
.inequalities
:
481 constraints
.append('{} >= 0'.format(constraint
))
482 return '{{{}}}'.format(', '.join(constraints
))
485 equalities
= list(self
.equalities
)
486 inequalities
= list(self
.inequalities
)
487 return '{}(equalities={!r}, inequalities={!r})' \
488 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
def fromstring(cls, string):
    """
    Alternate constructor taking the textual form of a polyhedron.

    Not implemented yet: always raises NotImplementedError.
    """
    raise NotImplementedError()
494 def _symbolunion(self
, *others
):
495 symbols
= set(self
.symbols())
497 symbols
.update(other
.symbols())
498 return sorted(symbols
)
500 def _to_isl(self
, symbols
=None):
502 symbols
= self
.symbols()
503 num_coefficients
= len(symbols
)
504 space
= libisl
.isl_space_set_alloc(_main_ctx
, 0, num_coefficients
)
505 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
506 ls
= libisl
.isl_local_space_from_space(space
)
507 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
508 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
509 '''if there are equalities/inequalities, take each constant and coefficient and add as a constraint to the basic set'''
510 if list(self
.equalities
): #check if any equalities exist
511 for eq
in self
.equalities
:
512 coeff_eq
= dict(eq
.coefficients
)
515 ceq
= libisl
.isl_constraint_set_constant_si(ceq
, value
)
517 num
= coeff_eq
.get(eq
)
518 iden
= symbols
.index(eq
)
519 ceq
= libisl
.isl_constraint_set_coefficient_si(ceq
, islhelper
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
520 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
521 if list(self
.inequalities
): #check if any inequalities exist
522 for ineq
in self
.inequalities
:
523 coeff_in
= dict(ineq
.coefficients
)
525 value
= ineq
.constant
526 cin
= libisl
.isl_constraint_set_constant_si(cin
, value
)
527 for ineq
in coeff_in
:
528 num
= coeff_in
.get(ineq
)
529 iden
= symbols
.index(ineq
)
530 cin
= libisl
.isl_constraint_set_coefficient_si(cin
, islhelper
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
531 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
532 bset
= BasicSet(bset
)
535 def from_isl(self
, bset
):
536 '''takes basic set in isl form and puts back into python version of polyhedron
537 isl example code gives isl form as:
538 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
539 our printer is giving form as:
540 b'{ [i0] : 1 = 0 }' '''
543 constraints
= libisl
.isl_basic_set_equalities_matrix(bset
, 3)
544 elif self
._inequalities
:
545 constraints
= libisl
.isl_basic_set_inequalities_matrix(bset
, 3)
549 empty
= None #eq(0,1)
550 universe
= None #Polyhedron()
552 if __name__
== '__main__':
553 ex1
= Expression(coefficients
={'a': 1, 'x': 2}, constant
=2)
554 ex2
= Expression(coefficients
={'a': 3 , 'b': 2}, constant
=3)
555 p
= Polyhedron(inequalities
=[ex1
, ex2
])