"""
Linear expressions and polyhedra, backed by the isl library through ctypes.

Work in progress: this module is still very messy, and notes mark the places
that are expected to change.
"""

import ctypes
import ctypes.util
import functools
import numbers
from fractions import Fraction
# fractions.gcd was removed in Python 3.9; math.gcd is the replacement.
from math import gcd

# NOTE(review): the original import line for `isl` was lost; the package path
# `pypol` is assumed — confirm against the project layout.
from pypol import isl

_ISL_PATH = ctypes.util.find_library('isl')
if _ISL_PATH is None:
    # ctypes.CDLL(None) would silently load the current process instead of
    # isl and fail later with a confusing missing-symbol error.
    raise ImportError('cannot find the isl shared library')
libisl = ctypes.CDLL(_ISL_PATH)

# isl_printer_get_str returns a C string; without this ctypes would treat the
# returned pointer as a plain int.
libisl.isl_printer_get_str.restype = ctypes.c_char_p

__all__ = [
    'Expression',
    'constant', 'symbol', 'symbols',
    'eq', 'le', 'lt', 'ge', 'gt',
    'Polyhedron',
    'universe',
]

# Shared isl context used when allocating isl objects (see Polyhedron.to_isl).
_CONTEXT = isl.Context()
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so its right operand may be rational.

    The wrapped ``func(a, b)`` is called unchanged when ``b`` is an
    Expression.  When ``b`` is a ``numbers.Rational`` it is first converted
    with ``constant(b)``.  Any other operand yields ``NotImplemented`` so
    Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(b, Expression):
            return func(a, b)
        if isinstance(b, numbers.Rational):
            return func(a, constant(b))
        return NotImplemented
    return wrapper
def _polymorphic_operator(func):
    """
    Decorate a module-level binary operator so its left operand may be rational.

    A rational ``a`` is converted with ``constant(a)``.  An Expression ``a``
    is passed through unchanged.  Any other left operand raises TypeError.
    """
    # A polymorphic operator should call a polymorphic method, hence we just
    # have to test the left operand.
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to rational coefficients and
    carries a rational constant term.  For example, ``2*x - y + 3`` is stored
    as ``{'x': 2, 'y': -1}`` plus the constant ``3``.  Zero coefficients are
    never stored, so the representation is canonical: ``==`` compares
    expressions structurally, and equal expressions hash equally.

    Arithmetic operators accept another Expression or a rational number
    (``int``, ``Fraction``, ...); a number is treated as a constant
    expression.  Multiplication and division are only allowed when the
    result stays linear; otherwise ``ValueError`` is raised.  The ordering
    operators and ``_eq`` do not return booleans: they build ``Polyhedron``
    constraints.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        :param coefficients: ``None``, a dict, or an iterable of
            ``(symbol, coefficient)`` pairs.  A symbol is a ``str`` or a
            symbol expression; a coefficient is a rational number.  A ``str``
            argument is forwarded to :meth:`fromstring`.
        :param constant: the rational constant term.
        :raises TypeError: for a string combined with a non-zero constant, a
            non-string symbol, or a non-rational coefficient or constant.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                # Dropping zeros keeps the representation canonical:
                # x - x is the constant 0, not {'x': 0}.
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    @staticmethod
    def _coerce(value):
        """
        Return *value* as an Expression, or None if that is impossible.

        Rational numbers become constant expressions.  None lets operators
        return ``NotImplemented`` for unsupported operand types.
        """
        if isinstance(value, Expression):
            return value
        if isinstance(value, numbers.Rational):
            return Expression(constant=value)
        return None

    def symbols(self):
        """Yield the symbols occurring in the expression, in sorted order."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols occurring in the expression."""
        # NOTE(review): the definition line of this accessor was lost in the
        # mangled source; the name `dimension` is assumed — confirm callers.
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of *symbol*, or 0 if it does not occur.

        :param symbol: a symbol name or a symbol expression.
        :raises TypeError: if *symbol* is neither.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        # Zero coefficients are never stored, so a missing symbol means 0.
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield ``(symbol, coefficient)`` pairs in sorted symbol order."""
        yield from sorted(self._coefficients.items())

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    def isconstant(self):
        """Return True if no symbol occurs in the expression."""
        return not self._coefficients

    def values(self):
        """Yield every coefficient in sorted symbol order, then the constant."""
        for _, coefficient in self.coefficients():
            yield coefficient
        yield self.constant

    def values_int(self):
        """Like :meth:`values`, but convert each value with ``int()``."""
        for value in self.values():
            yield int(value)

    @property
    def symbol(self):
        """
        The name of a symbol expression.

        :raises ValueError: if the expression is not exactly one symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """Return True if the expression is a bare symbol such as ``x``."""
        if len(self._coefficients) != 1 or self._constant != 0:
            return False
        # 2*x is not a symbol: its single coefficient must be exactly 1.
        return next(iter(self._coefficients.values())) == 1

    def __bool__(self):
        """An expression is false only when it is the constant 0."""
        return bool(self._coefficients) or bool(self._constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    def __add__(self, other):
        """Return ``self + other``; *other* may be a rational number."""
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        coefficients = dict(self._coefficients)
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        return Expression(coefficients, self.constant + other.constant)

    __radd__ = __add__

    def __sub__(self, other):
        """Return ``self - other``; *other* may be a rational number."""
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        coefficients = dict(self._coefficients)
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        return Expression(coefficients, self.constant - other.constant)

    def __rsub__(self, other):
        """Return ``other - self`` for a rational *other*."""
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return other - self

    def __mul__(self, other):
        """
        Return ``self * other``.

        :raises ValueError: if neither operand is constant (non-linear).
        """
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        # Scale whichever operand is not constant.  Both orders must be
        # handled here: Python never calls __rmul__ when both operands are
        # Expressions, so returning NotImplemented would make
        # constant(2) * x fail.
        if other.isconstant():
            factor, expression = other.constant, self
        elif self.isconstant():
            factor, expression = self.constant, other
        else:
            raise ValueError('non-linear expression: '
                             '{} * {}'.format(self._parenstr(),
                                              other._parenstr()))
        coefficients = {symbol: coefficient * factor
                        for symbol, coefficient in expression.coefficients()}
        return Expression(coefficients, expression.constant * factor)

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return ``self / other``, with exact ``Fraction`` results.

        :raises ValueError: if *other* is not constant (non-linear).
        :raises ZeroDivisionError: if *other* is the constant 0.
        """
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                             '{} / {}'.format(self._parenstr(),
                                              other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
                        for symbol, coefficient in self.coefficients()}
        return Expression(coefficients, Fraction(self.constant, divisor))

    def __rtruediv__(self, other):
        """
        Return ``other / self`` for a rational *other*.

        :raises ValueError: if *self* is not constant (non-linear).
        """
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                             '{} / {}'.format(other, self._parenstr()))
        return Expression(constant=Fraction(other, self.constant))

    def __str__(self):
        """Render e.g. ``2*x - y + 3``; the zero expression renders as ``0``."""
        # Collect (is_negative, unsigned_text) for each term, then join them.
        terms = []
        for symbol, coefficient in self.coefficients():
            if coefficient in (1, -1):
                text = symbol
            else:
                text = '{}*{}'.format(abs(coefficient), symbol)
            terms.append((coefficient < 0, text))
        constant = self.constant
        if constant != 0 or not terms:
            terms.append((constant < 0, str(abs(constant))))
        string = ''
        for position, (negative, text) in enumerate(terms):
            if position == 0:
                string = ('-' if negative else '') + text
            else:
                string += (' - ' if negative else ' + ') + text
        return string

    def _parenstr(self, always=False):
        """Return ``str(self)``, parenthesized unless it is a constant or symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                          for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                                         self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def __eq__(self, other):
        # "normal" structural equality returning a bool; use _eq() to build
        # an equality constraint instead.
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return (isinstance(other, Expression) and
                self._coefficients == other._coefficients and
                self.constant == other.constant)

    def __hash__(self):
        # A dict is unhashable; a frozenset of its items hashes consistently
        # with the dict equality used in __eq__.
        return hash((frozenset(self._coefficients.items()), self._constant))

    def _canonify(self):
        """Return the expression scaled by the LCM of all its denominators,
        so that every coefficient and the constant become integers."""
        from math import gcd  # fractions.gcd was removed in Python 3.9
        multiplier = 1
        for value in self.values():
            denominator = value.denominator
            multiplier = multiplier * denominator // gcd(multiplier,
                                                         denominator)
        return self * multiplier

    def _eq(self, other):
        """Return the polyhedron where ``self == other``."""
        other = self._coerce(other)
        if other is None:
            raise TypeError('arguments must be linear expressions')
        return Polyhedron(equalities=[(self - other)._canonify()])

    def __le__(self, other):
        """Return the polyhedron where ``self <= other``."""
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return Polyhedron(inequalities=[(self - other)._canonify()])

    def __lt__(self, other):
        """Return the polyhedron where ``self < other``.

        Encoded as ``self - other + 1 <= 0``, which is exact on integer
        points once the expression has integer coefficients.
        """
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    def __ge__(self, other):
        """Return the polyhedron where ``self >= other``."""
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return Polyhedron(inequalities=[(other - self)._canonify()])

    def __gt__(self, other):
        """Return the polyhedron where ``self > other`` (see ``__lt__``)."""
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
def constant(numerator=0, denominator=None):
    """
    Return a constant linear expression.

    A rational *numerator* with no *denominator* is wrapped unchanged.
    Otherwise the value is built with ``Fraction(numerator, denominator)``,
    so ``constant(1, 3)`` and ``constant('1/3')`` both give 1/3.
    """
    if denominator is None and isinstance(numerator, numbers.Rational):
        return Expression(constant=numerator)
    return Expression(constant=Fraction(numerator, denominator))
def symbol(name):
    """
    Return the linear expression made of the single symbol *name*.

    :raises TypeError: if *name* is not a string.
    """
    if not isinstance(name, str):
        raise TypeError('name must be a string')
    return Expression(coefficients={name: 1})
def symbols(names):
    """
    Return a generator of symbol expressions, one per name.

    *names* is an iterable of strings, or a single string whose names are
    separated by whitespace and/or commas, e.g. ``x, y, z = symbols('x y z')``.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return (symbol(name) for name in names)
@_polymorphic_operator
def eq(a, b):
    """Return the polyhedron where ``a == b``; *a* may be a rational number."""
    return a._eq(b)


@_polymorphic_operator
def le(a, b):
    """Return the polyhedron where ``a <= b``; *a* may be a rational number."""
    return a <= b


@_polymorphic_operator
def lt(a, b):
    """Return the polyhedron where ``a < b``; *a* may be a rational number."""
    return a < b


@_polymorphic_operator
def ge(a, b):
    """Return the polyhedron where ``a >= b``; *a* may be a rational number."""
    return a >= b


@_polymorphic_operator
def gt(a, b):
    """Return the polyhedron where ``a > b``; *a* may be a rational number."""
    return a > b
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is given by a list of equalities and a list of
    inequalities, both made of linear expressions.  An equality ``e`` stands
    for ``e == 0`` and an inequality ``e`` for ``e <= 0``.  Every
    coefficient and constant of a constraint must be an integer.  The
    corresponding isl basic set is only built when first needed (see
    :meth:`to_isl`), so constructing a polyhedron does not call into isl.
    """

    # Value of isl_dim_set in isl's `enum isl_dim_type`
    # (isl_dim_cst=0, isl_dim_param=1, isl_dim_in=2,
    #  isl_dim_out=isl_dim_set=3).
    _ISL_DIM_SET = 3

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from *equalities* (``e == 0``) and
        *inequalities* (``e <= 0``).  A single string is forwarded to
        :meth:`fromstring`.

        :raises TypeError: for a string plus inequalities, or for a
            constraint with a non-integer value.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integer_constraints(equalities, '{} == 0')
        self._inequalities = cls._integer_constraints(inequalities,
                                                      '{} <= 0')
        # Cached isl_basic_set pointer, built lazily by _get_bset().
        # NOTE(review): it is never released with isl_basic_set_free.
        self._bset = None
        return self

    @staticmethod
    def _integer_constraints(constraints, template):
        """Return *constraints* as a list, rejecting non-integer values.

        *template* renders the offending constraint in the error message.
        """
        checked = []
        if constraints is not None:
            for constraint in constraints:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: ' +
                                        template.format(constraint))
                checked.append(constraint)
        return checked

    def equalities(self):
        """Yield the equality constraints (each meaning ``e == 0``)."""
        yield from self._equalities

    def inequalities(self):
        """Yield the inequality constraints (each meaning ``e <= 0``)."""
        yield from self._inequalities

    def isempty(self):
        """
        Return True if isl reports the polyhedron as empty.

        :raises RuntimeError: if isl signals an error.
        """
        bset = self._get_bset()
        # isl_bool: -1 on error, 0 for false, 1 for true.
        result = libisl.isl_basic_set_is_empty(bset)
        if result < 0:
            raise RuntimeError('isl_basic_set_is_empty failed')
        return bool(result)

    def constraints(self):
        """Yield all equalities, then all inequalities."""
        yield from self.equalities()
        yield from self.inequalities()

    def symbols(self):
        """Return the sorted list of symbols used by any constraint."""
        found = set()
        for constraint in self.constraints():
            found.update(constraint.symbols())
        return sorted(found)

    def symbol_count(self):
        """Return the number of distinct symbols used by the constraints."""
        return len(self.symbols())

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        # NOTE(review): the definition line of this accessor was lost in the
        # mangled source; the name `dimension` is assumed — confirm callers.
        return len(self.symbols())

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        if not self._equalities and not self._inequalities:
            # Without constraints this is the universe, which is non-empty.
            return True
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isuniverse(self):
        # Relies on __eq__, so this currently raises NotImplementedError.
        return self == universe

    def isdisjoint(self, other):
        """Return True if the polyhedron has no elements in common with
        *other* (the emptiness test is delegated to isl)."""
        return self.intersection(other).isempty()

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """Return a new polyhedron with the elements common to the
        polyhedron and all *others*: the conjunction of their constraints."""
        equalities = list(self._equalities)
        inequalities = list(self._inequalities)
        for other in others:
            equalities.extend(other.equalities())
            inequalities.extend(other.inequalities())
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are
        # not in any of the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(constraint)
                       for constraint in self.equalities()]
        constraints.extend('{} <= 0'.format(constraint)
                           for constraint in self.inequalities())
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        return '{}(equalities={!r}, inequalities={!r})'.format(
            self.__class__.__name__, list(self.equalities()),
            list(self.inequalities()))

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def _get_bset(self):
        """Return the isl basic set for this polyhedron, building it once."""
        if self._bset is None:
            self._bset = self.to_isl()
        return self._bset

    @staticmethod
    def _declare_isl_signatures():
        """Declare ctypes signatures for the isl functions used here.

        Without restype/argtypes, ctypes treats returned pointers as C ints
        and passes Python ints as C ints, truncating 64-bit pointers.
        """
        pointer, integer = ctypes.c_void_p, ctypes.c_int
        signatures = {
            'isl_space_set_alloc': (pointer,
                                    [pointer, ctypes.c_uint, ctypes.c_uint]),
            'isl_space_copy': (pointer, [pointer]),
            'isl_basic_set_universe': (pointer, [pointer]),
            'isl_local_space_from_space': (pointer, [pointer]),
            'isl_local_space_copy': (pointer, [pointer]),
            'isl_local_space_free': (pointer, [pointer]),
            'isl_equality_alloc': (pointer, [pointer]),
            'isl_inequality_alloc': (pointer, [pointer]),
            'isl_constraint_set_constant_si': (pointer, [pointer, integer]),
            'isl_constraint_set_coefficient_si': (
                pointer, [pointer, integer, integer, integer]),
            'isl_basic_set_add_constraint': (pointer, [pointer, pointer]),
            'isl_basic_set_is_empty': (integer, [pointer]),
        }
        for name, (restype, argtypes) in signatures.items():
            function = getattr(libisl, name)
            function.restype = restype
            function.argtypes = argtypes

    @classmethod
    def _add_isl_constraint(cls, bset, isl_constraint, constraint, positions,
                            sign):
        """Fill *isl_constraint* from *constraint* multiplied by *sign*,
        add it to *bset* and return the new basic set.

        Values are passed to isl as C ints, so they must fit in one.
        """
        isl_constraint = libisl.isl_constraint_set_constant_si(
            isl_constraint, sign * int(constraint.constant))
        for symbol, coefficient in constraint.coefficients():
            isl_constraint = libisl.isl_constraint_set_coefficient_si(
                isl_constraint, cls._ISL_DIM_SET, positions[symbol],
                sign * int(coefficient))
        return libisl.isl_basic_set_add_constraint(bset, isl_constraint)

    def to_isl(self):
        """
        Build an isl basic set (a raw ``isl_basic_set *``) for this polyhedron.

        Set dimension ``i`` corresponds to ``self.symbols()[i]``.
        NOTE(review): assumes ctypes accepts ``_CONTEXT`` as an ``isl_ctx *``
        (e.g. through ``_as_parameter_``) — confirm against the isl wrapper.

        :raises RuntimeError: if isl fails to build the set.
        """
        self._declare_isl_signatures()
        symbols = self.symbols()
        positions = {symbol: index for index, symbol in enumerate(symbols)}
        space = libisl.isl_space_set_alloc(_CONTEXT, 0, len(symbols))
        local_space = libisl.isl_local_space_from_space(
            libisl.isl_space_copy(space))
        bset = libisl.isl_basic_set_universe(space)  # takes ownership of space
        for constraint in self._equalities:
            isl_constraint = libisl.isl_equality_alloc(
                libisl.isl_local_space_copy(local_space))
            bset = self._add_isl_constraint(bset, isl_constraint, constraint,
                                            positions, 1)
        for constraint in self._inequalities:
            # isl inequalities mean `expr >= 0` while ours mean `expr <= 0`,
            # hence the sign flip.
            isl_constraint = libisl.isl_inequality_alloc(
                libisl.isl_local_space_copy(local_space))
            bset = self._add_isl_constraint(bset, isl_constraint, constraint,
                                            positions, -1)
        libisl.isl_local_space_free(local_space)
        if not bset:
            raise RuntimeError('isl failed to build the basic set')
        return bset
# The polyhedron without any constraint.
universe = Polyhedron()
= Polyhedron()