331d3af6f6d0f01208af69bdfd50fb0de732055f
1 import ctypes
, ctypes
.util
5 from fractions
import Fraction
, gcd
7 from . import isl
, islhelper
8 from .isl
import libisl
, Context
13 'constant', 'symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
20 def _polymorphic_method(func
):
21 @functools.wraps(func
)
23 if isinstance(b
, Expression
):
25 if isinstance(b
, numbers
.Rational
):
31 def _polymorphic_operator(func
):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func
)
36 if isinstance(a
, numbers
.Rational
):
39 elif isinstance(a
, Expression
):
41 raise TypeError('arguments must be linear expressions')
47 This class implements linear expressions.
50 def __new__(cls
, coefficients
=None, constant
=0):
51 if isinstance(coefficients
, str):
53 raise TypeError('too many arguments')
54 return cls
.fromstring(coefficients
)
55 self
= super().__new
__(cls
)
56 self
._coefficients
= {}
57 if isinstance(coefficients
, dict):
58 coefficients
= coefficients
.items()
59 if coefficients
is not None:
60 for symbol
, coefficient
in coefficients
:
61 if isinstance(symbol
, Expression
) and symbol
.issymbol():
63 elif not isinstance(symbol
, str):
64 raise TypeError('symbols must be strings')
65 if not isinstance(coefficient
, numbers
.Rational
):
66 raise TypeError('coefficients must be rational numbers')
68 self
._coefficients
[symbol
] = coefficient
69 if not isinstance(constant
, numbers
.Rational
):
70 raise TypeError('constant must be a rational number')
71 self
._constant
= constant
76 yield from sorted(self
._coefficients
)
80 return len(list(self
.symbols()))
82 def coefficient(self
, symbol
):
83 if isinstance(symbol
, Expression
) and symbol
.issymbol():
85 elif not isinstance(symbol
, str):
86 raise TypeError('symbol must be a string')
88 return self
._coefficients
[symbol
]
92 __getitem__
= coefficient
95 def coefficients(self
):
96 for symbol
in self
.symbols():
97 yield symbol
, self
.coefficient(symbol
)
101 return self
._constant
103 def isconstant(self
):
104 return len(self
._coefficients
) == 0
107 for symbol
in self
.symbols():
108 yield self
.coefficient(symbol
)
111 def values_int(self
):
112 for symbol
in self
.symbols():
113 return self
.coefficient(symbol
)
114 return int(self
.constant
)
117 if not self
.issymbol():
118 raise ValueError('not a symbol: {}'.format(self
))
119 for symbol
in self
.symbols():
123 return len(self
._coefficients
) == 1 and self
._constant
== 0
126 return (not self
.isconstant()) or bool(self
.constant
)
135 def __add__(self
, other
):
136 coefficients
= dict(self
.coefficients
)
137 for symbol
, coefficient
in other
.coefficients
:
138 if symbol
in coefficients
:
139 coefficients
[symbol
] += coefficient
141 coefficients
[symbol
] = coefficient
142 constant
= self
.constant
+ other
.constant
143 return Expression(coefficients
, constant
)
148 def __sub__(self
, other
):
149 coefficients
= dict(self
.coefficients
)
150 for symbol
, coefficient
in other
.coefficients
:
151 if symbol
in coefficients
:
152 coefficients
[symbol
] -= coefficient
154 coefficients
[symbol
] = -coefficient
155 constant
= self
.constant
- other
.constant
156 return Expression(coefficients
, constant
)
158 def __rsub__(self
, other
):
159 return -(self
- other
)
162 def __mul__(self
, other
):
163 if other
.isconstant():
164 coefficients
= dict(self
.coefficients
)
165 for symbol
in coefficients
:
166 coefficients
[symbol
] *= other
.constant
167 constant
= self
.constant
* other
.constant
168 return Expression(coefficients
, constant
)
169 if isinstance(other
, Expression
) and not self
.isconstant():
170 raise ValueError('non-linear expression: '
171 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
172 return NotImplemented
177 def __truediv__(self
, other
):
178 if other
.isconstant():
179 coefficients
= dict(self
.coefficients())
180 for symbol
in coefficients
:
181 coefficients
[symbol
] = \
182 Fraction(coefficients
[symbol
], other
.constant
)
183 constant
= Fraction(self
.constant
, other
.constant
)
184 return Expression(coefficients
, constant
)
185 if isinstance(other
, Expression
):
186 raise ValueError('non-linear expression: '
187 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
188 return NotImplemented
190 def __rtruediv__(self
, other
):
191 if isinstance(other
, self
):
192 if self
.isconstant():
193 constant
= Fraction(other
, self
.constant
)
194 return Expression(constant
=constant
)
196 raise ValueError('non-linear expression: '
197 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
198 return NotImplemented
202 symbols
= sorted(self
.symbols())
204 for symbol
in symbols
:
205 coefficient
= self
[symbol
]
210 string
+= ' + {}'.format(symbol
)
211 elif coefficient
== -1:
213 string
+= '-{}'.format(symbol
)
215 string
+= ' - {}'.format(symbol
)
218 string
+= '{}*{}'.format(coefficient
, symbol
)
219 elif coefficient
> 0:
220 string
+= ' + {}*{}'.format(coefficient
, symbol
)
222 assert coefficient
< 0
224 string
+= ' - {}*{}'.format(coefficient
, symbol
)
226 constant
= self
.constant
227 if constant
!= 0 and i
== 0:
228 string
+= '{}'.format(constant
)
230 string
+= ' + {}'.format(constant
)
233 string
+= ' - {}'.format(constant
)
238 def _parenstr(self
, always
=False):
240 if not always
and (self
.isconstant() or self
.issymbol()):
243 return '({})'.format(string
)
246 string
= '{}({{'.format(self
.__class
__.__name
__)
247 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients
):
250 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
251 string
+= '}}, {!r})'.format(self
.constant
)
255 def fromstring(cls
, string
):
256 raise NotImplementedError
259 def __eq__(self
, other
):
261 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
262 return isinstance(other
, Expression
) and \
263 self
._coefficients
== other
._coefficients
and \
264 self
.constant
== other
.constant
267 return hash((self
._coefficients
, self
._constant
))
270 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
271 [value
.denominator
for value
in self
.values()])
275 def _eq(self
, other
):
276 return Polyhedron(equalities
=[(self
- other
)._canonify
()])
279 def __le__(self
, other
):
280 return Polyhedron(inequalities
=[(other
- self
)._canonify
()])
283 def __lt__(self
, other
):
284 return Polyhedron(inequalities
=[(other
- self
)._canonify
() - 1])
287 def __ge__(self
, other
):
288 return Polyhedron(inequalities
=[(self
- other
)._canonify
()])
291 def __gt__(self
, other
):
292 return Polyhedron(inequalities
=[(self
- other
)._canonify
() - 1])
295 def constant(numerator
=0, denominator
=None):
296 if denominator
is None and isinstance(numerator
, numbers
.Rational
):
297 return Expression(constant
=numerator
)
299 return Expression(constant
=Fraction(numerator
, denominator
))
302 if not isinstance(name
, str):
303 raise TypeError('name must be a string')
304 return Expression(coefficients
={name
: 1})
307 if isinstance(names
, str):
308 names
= names
.replace(',', ' ').split()
309 return (symbol(name
) for name
in names
)
312 @_polymorphic_operator
316 @_polymorphic_operator
320 @_polymorphic_operator
324 @_polymorphic_operator
328 @_polymorphic_operator
335 This class implements polyhedrons.
338 def __new__(cls
, equalities
=None, inequalities
=None):
339 if isinstance(equalities
, str):
340 if inequalities
is not None:
341 raise TypeError('too many arguments')
342 return cls
.fromstring(equalities
)
343 self
= super().__new
__(cls
)
344 self
._equalities
= []
345 if equalities
is not None:
346 for constraint
in equalities
:
347 for value
in constraint
.values():
348 if value
.denominator
!= 1:
349 raise TypeError('non-integer constraint: '
350 '{} == 0'.format(constraint
))
351 self
._equalities
.append(constraint
)
352 self
._inequalities
= []
353 if inequalities
is not None:
354 for constraint
in inequalities
:
355 for value
in constraint
.values():
356 if value
.denominator
!= 1:
357 raise TypeError('non-integer constraint: '
358 '{} <= 0'.format(constraint
))
359 self
._inequalities
.append(constraint
)
360 self
._bset
= self
._to
_isl
()
362 #put this here just to test from isl method
363 #from_isl = self.from_isl(self._bset)
369 def equalities(self
):
370 yield from self
._equalities
373 def inequalities(self
):
374 yield from self
._inequalities
378 return self
._constant
380 def isconstant(self
):
381 return len(self
._coefficients
) == 0
384 return bool(libisl
.isl_basic_set_is_empty(self
._bset
))
386 def constraints(self
):
387 yield from self
.equalities
388 yield from self
.inequalities
392 for constraint
in self
.constraints():
393 s
.update(constraint
.symbols())
398 return len(self
.symbols())
401 # return false if the polyhedron is empty, true otherwise
402 if self
._equalities
or self
._inequalities
:
407 def __contains__(self
, value
):
408 # is the value in the polyhedron?
409 raise NotImplementedError
411 def __eq__(self
, other
):
412 raise NotImplementedError
417 def isuniverse(self
):
418 return self
== universe
420 def isdisjoint(self
, other
):
421 # return true if the polyhedron has no elements in common with other
422 raise NotImplementedError
424 def issubset(self
, other
):
425 raise NotImplementedError
427 def __le__(self
, other
):
428 return self
.issubset(other
)
430 def __lt__(self
, other
):
431 raise NotImplementedError
433 def issuperset(self
, other
):
434 # test whether every element in other is in the polyhedron
436 if value
== self
.constraints():
440 raise NotImplementedError
442 def __ge__(self
, other
):
443 return self
.issuperset(other
)
445 def __gt__(self
, other
):
446 raise NotImplementedError
448 def union(self
, *others
):
449 # return a new polyhedron with elements from the polyhedron and all
450 # others (convex union)
451 raise NotImplementedError
453 def __or__(self
, other
):
454 return self
.union(other
)
456 def intersection(self
, *others
):
457 # return a new polyhedron with elements common to the polyhedron and all
459 # a poor man's implementation could be:
460 # equalities = list(self.equalities)
461 # inequalities = list(self.inequalities)
462 # for other in others:
463 # equalities.extend(other.equalities)
464 # inequalities.extend(other.inequalities)
465 # return self.__class__(equalities, inequalities)
466 raise NotImplementedError
468 def __and__(self
, other
):
469 return self
.intersection(other
)
471 def difference(self
, *others
):
472 # return a new polyhedron with elements in the polyhedron that are not
474 raise NotImplementedError
476 def __sub__(self
, other
):
477 return self
.difference(other
)
481 for constraint
in self
.equalities
:
482 constraints
.append('{} == 0'.format(constraint
))
483 for constraint
in self
.inequalities
:
484 constraints
.append('{} >= 0'.format(constraint
))
485 return '{{{}}}'.format(', '.join(constraints
))
488 equalities
= list(self
.equalities
)
489 inequalities
= list(self
.inequalities
)
490 return '{}(equalities={!r}, inequalities={!r})' \
491 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
494 def fromstring(cls
, string
):
495 raise NotImplementedError
497 def _symbolunion(self
, *others
):
498 symbols
= set(self
.symbols())
500 symbols
.update(other
.symbols())
501 return sorted(symbols
)
503 def _to_isl(self
, symbols
=None):
505 symbols
= self
.symbols()
506 num_coefficients
= len(symbols
)
508 space
= libisl
.isl_space_set_alloc(ctx
, 0, num_coefficients
)
509 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
510 ls
= libisl
.isl_local_space_from_space(space
)
511 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
512 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
513 '''if there are equalities/inequalities, take each constant and coefficient and add as a constraint to the basic set'''
514 if list(self
.equalities
): #check if any equalities exist
515 for eq
in self
.equalities
:
516 coeff_eq
= dict(eq
.coefficients
)
519 ceq
= libisl
.isl_constraint_set_constant_si(ceq
, value
)
521 num
= coeff_eq
.get(eq
)
522 iden
= symbols
.index(eq
)
523 ceq
= libisl
.isl_constraint_set_coefficient_si(ceq
, islhelper
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
524 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
525 if list(self
.inequalities
): #check if any inequalities exist
526 for ineq
in self
.inequalities
:
527 coeff_in
= dict(ineq
.coefficients
)
529 value
= ineq
.constant
530 cin
= libisl
.isl_constraint_set_constant_si(cin
, value
)
531 for ineq
in coeff_in
:
532 num
= coeff_in
.get(ineq
)
533 iden
= symbols
.index(ineq
)
534 cin
= libisl
.isl_constraint_set_coefficient_si(cin
, islhelper
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
535 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
536 ip
= libisl
.isl_printer_to_str(ctx
) #create string printer
537 ip
= libisl
.isl_printer_print_basic_set(ip
, bset
) #print basic set to printer
538 string
= libisl
.isl_printer_get_str(ip
) #get string from printer
539 string
= str(string
.decode())
543 def from_isl(self
, bset
):
544 '''takes basic set in isl form and puts back into python version of polyhedron
545 isl example code gives isl form as:
546 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
547 our printer is giving form as:
548 b'{ [i0] : 1 = 0 }' '''
551 constraints
= libisl
.isl_basic_set_equalities_matrix(bset
, 3)
552 elif self
._inequalities
:
553 constraints
= libisl
.isl_basic_set_inequalities_matrix(bset
, 3)
557 empty
= None #eq(0,1)
558 universe
= None #Polyhedron()
560 if __name__
== '__main__':
561 ex1
= Expression(coefficients
={'a': 1, 'x': 2}, constant
=2)
562 ex2
= Expression(coefficients
={'a': 3 , 'b': 2}, constant
=3)
563 p
= Polyhedron(inequalities
=[ex1
, ex2
])
564 #p = eq(ex2, 0)# 2a+4 = 0, in fact 6a+3 = 0
567 #universe = Polyhedron()