348294ce808d0306890ffb548ab28c628448b9e6
1 import ctypes
, ctypes
.util
5 from fractions
import Fraction
, gcd
7 from . import isl
, islhelper
8 from .isl
import libisl
, Context
, BasicSet
13 'constant', 'symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
20 def _polymorphic_method(func
):
21 @functools.wraps(func
)
23 if isinstance(b
, Expression
):
25 if isinstance(b
, numbers
.Rational
):
31 def _polymorphic_operator(func
):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func
)
36 if isinstance(a
, numbers
.Rational
):
39 elif isinstance(a
, Expression
):
41 raise TypeError('arguments must be linear expressions')
47 This class implements linear expressions.
50 def __new__(cls
, coefficients
=None, constant
=0):
51 if isinstance(coefficients
, str):
53 raise TypeError('too many arguments')
54 return cls
.fromstring(coefficients
)
55 self
= super().__new
__(cls
)
56 self
._coefficients
= {}
57 if isinstance(coefficients
, dict):
58 coefficients
= coefficients
.items()
59 if coefficients
is not None:
60 for symbol
, coefficient
in coefficients
:
61 if isinstance(symbol
, Expression
) and symbol
.issymbol():
63 elif not isinstance(symbol
, str):
64 raise TypeError('symbols must be strings')
65 if not isinstance(coefficient
, numbers
.Rational
):
66 raise TypeError('coefficients must be rational numbers')
68 self
._coefficients
[symbol
] = coefficient
69 if not isinstance(constant
, numbers
.Rational
):
70 raise TypeError('constant must be a rational number')
71 self
._constant
= constant
76 yield from sorted(self
._coefficients
)
80 return len(list(self
.symbols()))
82 def coefficient(self
, symbol
):
83 if isinstance(symbol
, Expression
) and symbol
.issymbol():
85 elif not isinstance(symbol
, str):
86 raise TypeError('symbol must be a string')
88 return self
._coefficients
[symbol
]
92 __getitem__
= coefficient
95 def coefficients(self
):
96 for symbol
in self
.symbols():
97 yield symbol
, self
.coefficient(symbol
)
101 return self
._constant
103 def isconstant(self
):
104 return len(self
._coefficients
) == 0
107 for symbol
in self
.symbols():
108 yield self
.coefficient(symbol
)
111 def values_int(self
):
112 for symbol
in self
.symbols():
113 return self
.coefficient(symbol
)
114 return int(self
.constant
)
117 if not self
.issymbol():
118 raise ValueError('not a symbol: {}'.format(self
))
119 for symbol
in self
.symbols():
123 return len(self
._coefficients
) == 1 and self
._constant
== 0
126 return (not self
.isconstant()) or bool(self
.constant
)
135 def __add__(self
, other
):
136 coefficients
= dict(self
.coefficients
)
137 for symbol
, coefficient
in other
.coefficients
:
138 if symbol
in coefficients
:
139 coefficients
[symbol
] += coefficient
141 coefficients
[symbol
] = coefficient
142 constant
= self
.constant
+ other
.constant
143 return Expression(coefficients
, constant
)
148 def __sub__(self
, other
):
149 coefficients
= dict(self
.coefficients
)
150 for symbol
, coefficient
in other
.coefficients
:
151 if symbol
in coefficients
:
152 coefficients
[symbol
] -= coefficient
154 coefficients
[symbol
] = -coefficient
155 constant
= self
.constant
- other
.constant
156 return Expression(coefficients
, constant
)
158 def __rsub__(self
, other
):
159 return -(self
- other
)
162 def __mul__(self
, other
):
163 if other
.isconstant():
164 coefficients
= dict(self
.coefficients
)
165 for symbol
in coefficients
:
166 coefficients
[symbol
] *= other
.constant
167 constant
= self
.constant
* other
.constant
168 return Expression(coefficients
, constant
)
169 if isinstance(other
, Expression
) and not self
.isconstant():
170 raise ValueError('non-linear expression: '
171 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
172 return NotImplemented
177 def __truediv__(self
, other
):
178 if other
.isconstant():
179 coefficients
= dict(self
.coefficients())
180 for symbol
in coefficients
:
181 coefficients
[symbol
] = \
182 Fraction(coefficients
[symbol
], other
.constant
)
183 constant
= Fraction(self
.constant
, other
.constant
)
184 return Expression(coefficients
, constant
)
185 if isinstance(other
, Expression
):
186 raise ValueError('non-linear expression: '
187 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
188 return NotImplemented
190 def __rtruediv__(self
, other
):
191 if isinstance(other
, self
):
192 if self
.isconstant():
193 constant
= Fraction(other
, self
.constant
)
194 return Expression(constant
=constant
)
196 raise ValueError('non-linear expression: '
197 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
198 return NotImplemented
202 symbols
= sorted(self
.symbols())
204 for symbol
in symbols
:
205 coefficient
= self
[symbol
]
210 string
+= ' + {}'.format(symbol
)
211 elif coefficient
== -1:
213 string
+= '-{}'.format(symbol
)
215 string
+= ' - {}'.format(symbol
)
218 string
+= '{}*{}'.format(coefficient
, symbol
)
219 elif coefficient
> 0:
220 string
+= ' + {}*{}'.format(coefficient
, symbol
)
222 assert coefficient
< 0
224 string
+= ' - {}*{}'.format(coefficient
, symbol
)
226 constant
= self
.constant
227 if constant
!= 0 and i
== 0:
228 string
+= '{}'.format(constant
)
230 string
+= ' + {}'.format(constant
)
233 string
+= ' - {}'.format(constant
)
238 def _parenstr(self
, always
=False):
240 if not always
and (self
.isconstant() or self
.issymbol()):
243 return '({})'.format(string
)
246 string
= '{}({{'.format(self
.__class
__.__name
__)
247 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients
):
250 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
251 string
+= '}}, {!r})'.format(self
.constant
)
255 def fromstring(cls
, string
):
256 raise NotImplementedError
259 def __eq__(self
, other
):
261 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
262 return isinstance(other
, Expression
) and \
263 self
._coefficients
== other
._coefficients
and \
264 self
.constant
== other
.constant
267 return hash((self
._coefficients
, self
._constant
))
270 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
271 [value
.denominator
for value
in self
.values()])
275 def _eq(self
, other
):
276 return Polyhedron(equalities
=[(self
- other
)._canonify
()])
279 def __le__(self
, other
):
280 return Polyhedron(inequalities
=[(other
- self
)._canonify
()])
283 def __lt__(self
, other
):
284 return Polyhedron(inequalities
=[(other
- self
)._canonify
() - 1])
287 def __ge__(self
, other
):
288 return Polyhedron(inequalities
=[(self
- other
)._canonify
()])
291 def __gt__(self
, other
):
292 return Polyhedron(inequalities
=[(self
- other
)._canonify
() - 1])
295 def constant(numerator
=0, denominator
=None):
296 if denominator
is None and isinstance(numerator
, numbers
.Rational
):
297 return Expression(constant
=numerator
)
299 return Expression(constant
=Fraction(numerator
, denominator
))
302 if not isinstance(name
, str):
303 raise TypeError('name must be a string')
304 return Expression(coefficients
={name
: 1})
307 if isinstance(names
, str):
308 names
= names
.replace(',', ' ').split()
309 return (symbol(name
) for name
in names
)
312 @_polymorphic_operator
316 @_polymorphic_operator
320 @_polymorphic_operator
324 @_polymorphic_operator
328 @_polymorphic_operator
335 This class implements polyhedrons.
338 def __new__(cls
, equalities
=None, inequalities
=None):
339 if isinstance(equalities
, str):
340 if inequalities
is not None:
341 raise TypeError('too many arguments')
342 return cls
.fromstring(equalities
)
343 self
= super().__new
__(cls
)
344 self
._equalities
= []
345 if equalities
is not None:
346 for constraint
in equalities
:
347 for value
in constraint
.values():
348 if value
.denominator
!= 1:
349 raise TypeError('non-integer constraint: '
350 '{} == 0'.format(constraint
))
351 self
._equalities
.append(constraint
)
352 self
._inequalities
= []
353 if inequalities
is not None:
354 for constraint
in inequalities
:
355 for value
in constraint
.values():
356 if value
.denominator
!= 1:
357 raise TypeError('non-integer constraint: '
358 '{} <= 0'.format(constraint
))
359 self
._inequalities
.append(constraint
)
363 def equalities(self
):
364 yield from self
._equalities
367 def inequalities(self
):
368 yield from self
._inequalities
372 return self
._constant
374 def isconstant(self
):
375 return len(self
._coefficients
) == 0
378 return bool(libisl
.isl_basic_set_is_empty(self
._bset
))
380 def constraints(self
):
381 yield from self
.equalities
382 yield from self
.inequalities
386 for constraint
in self
.constraints():
387 s
.update(constraint
.symbols())
392 return len(self
.symbols())
395 # return false if the polyhedron is empty, true otherwise
396 if self
._equalities
or self
._inequalities
:
401 def __contains__(self
, value
):
402 # is the value in the polyhedron?
403 raise NotImplementedError
405 def __eq__(self
, other
):
406 raise NotImplementedError
411 def isuniverse(self
):
412 return self
== universe
414 def isdisjoint(self
, other
):
415 # return true if the polyhedron has no elements in common with other
416 raise NotImplementedError
418 def issubset(self
, other
):
419 raise NotImplementedError
421 def __le__(self
, other
):
422 return self
.issubset(other
)
424 def __lt__(self
, other
):
425 raise NotImplementedError
427 def issuperset(self
, other
):
428 # test whether every element in other is in the polyhedron
430 if value
== self
.constraints():
434 raise NotImplementedError
436 def __ge__(self
, other
):
437 return self
.issuperset(other
)
439 def __gt__(self
, other
):
440 raise NotImplementedError
442 def union(self
, *others
):
443 # return a new polyhedron with elements from the polyhedron and all
444 # others (convex union)
445 raise NotImplementedError
447 def __or__(self
, other
):
448 return self
.union(other
)
450 def intersection(self
, *others
):
451 # return a new polyhedron with elements common to the polyhedron and all
453 # a poor man's implementation could be:
454 # equalities = list(self.equalities)
455 # inequalities = list(self.inequalities)
456 # for other in others:
457 # equalities.extend(other.equalities)
458 # inequalities.extend(other.inequalities)
459 # return self.__class__(equalities, inequalities)
460 raise NotImplementedError
462 def __and__(self
, other
):
463 return self
.intersection(other
)
465 def difference(self
, *others
):
466 # return a new polyhedron with elements in the polyhedron that are not
468 raise NotImplementedError
470 def __sub__(self
, other
):
471 return self
.difference(other
)
475 for constraint
in self
.equalities
:
476 constraints
.append('{} == 0'.format(constraint
))
477 for constraint
in self
.inequalities
:
478 constraints
.append('{} >= 0'.format(constraint
))
479 return '{{{}}}'.format(', '.join(constraints
))
482 equalities
= list(self
.equalities
)
483 inequalities
= list(self
.inequalities
)
484 return '{}(equalities={!r}, inequalities={!r})' \
485 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
488 def fromstring(cls
, string
):
489 raise NotImplementedError
491 def _symbolunion(self
, *others
):
492 symbols
= set(self
.symbols())
494 symbols
.update(other
.symbols())
495 return sorted(symbols
)
497 def _to_isl(self
, symbols
=None):
499 symbols
= self
.symbols()
500 num_coefficients
= len(symbols
)
502 space
= libisl
.isl_space_set_alloc(ctx
, 0, num_coefficients
)
503 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
504 ls
= libisl
.isl_local_space_from_space(space
)
505 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
506 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
507 '''if there are equalities/inequalities, take each constant and coefficient and add as a constraint to the basic set'''
508 if list(self
.equalities
): #check if any equalities exist
509 for eq
in self
.equalities
:
510 coeff_eq
= dict(eq
.coefficients
)
513 ceq
= libisl
.isl_constraint_set_constant_si(ceq
, value
)
515 num
= coeff_eq
.get(eq
)
516 iden
= symbols
.index(eq
)
517 ceq
= libisl
.isl_constraint_set_coefficient_si(ceq
, islhelper
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
518 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
519 if list(self
.inequalities
): #check if any inequalities exist
520 for ineq
in self
.inequalities
:
521 coeff_in
= dict(ineq
.coefficients
)
523 value
= ineq
.constant
524 cin
= libisl
.isl_constraint_set_constant_si(cin
, value
)
525 for ineq
in coeff_in
:
526 num
= coeff_in
.get(ineq
)
527 iden
= symbols
.index(ineq
)
528 cin
= libisl
.isl_constraint_set_coefficient_si(cin
, islhelper
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
529 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
530 bset
= BasicSet(bset
)
533 def from_isl(self
, bset
):
534 '''takes basic set in isl form and puts back into python version of polyhedron
535 isl example code gives isl form as:
536 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
537 our printer is giving form as:
538 b'{ [i0] : 1 = 0 }' '''
541 constraints
= libisl
.isl_basic_set_equalities_matrix(bset
, 3)
542 elif self
._inequalities
:
543 constraints
= libisl
.isl_basic_set_inequalities_matrix(bset
, 3)
547 empty
= None #eq(0,1)
548 universe
= None #Polyhedron()
550 if __name__
== '__main__':
551 ex1
= Expression(coefficients
={'a': 1, 'x': 2}, constant
=2)
552 ex2
= Expression(coefficients
={'a': 3 , 'b': 2}, constant
=3)
553 p
= Polyhedron(inequalities
=[ex1
, ex2
])