Values used for _toisl (some expressions are not correct)
[linpy.git] / pypol / linear.py
1 import ctypes, ctypes.util
2 import functools
3 import numbers
4
5 from fractions import Fraction, gcd
6
7 from pypol import isl
8 from pypol.isl import libisl
9
10
__all__ = [
    # Linear expressions and helpers to create them.
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    # Comparison helpers.
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    # Polyhedra and the predefined extreme polyhedra.
    'Polyhedron',
    'empty',
    'universe',
]
17
18
def _polymorphic_method(func):
    """Decorate a binary Expression method so it accepts rational operands.

    A right operand that is an Expression is passed through unchanged, a
    rational number is promoted to a Constant, and anything else makes the
    wrapper return NotImplemented so Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
29
def _polymorphic_operator(func):
    """Decorate a module-level binary operator so it accepts a rational left operand.

    The decorated operator is expected to call a polymorphic method, which
    already handles the right operand, so only the left one is checked here:
    a rational number is promoted to a Constant, an Expression is kept, and
    anything else raises TypeError.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
42
43
# Module-wide isl context, created once at import time; Polyhedron._toisl
# passes it to every libisl space/value allocation.
_main_ctx = isl.Context()
45
46
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (str) to rational coefficients and has a
    rational constant term. Construction may return a subclass instead: a
    Constant when no non-zero coefficient remains, or a Symbol when the
    expression is exactly ``1*symbol``.
    """

    def __new__(cls, coefficients=None, constant=0):
        """Create a linear expression.

        coefficients: a dict, or an iterable of (symbol, coefficient) pairs;
            symbols are str or Symbol instances, coefficients are rational
            numbers or Constant instances. Zero coefficients are dropped.
            A str is forwarded to fromstring() (not implemented yet).
        constant: a rational number or a Constant instance.

        Raises TypeError for wrongly typed symbols, coefficients or constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # object.__new__ directly: this replaces the former object().__new__,
        # which allocated a throwaway object just to reach the same function.
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = str(symbol)
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def fromstring(cls, string):
        """Parse an expression from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def symbols(self):
        """Sorted tuple of the symbol names with a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """Return the coefficient of ``symbol`` (str or Symbol), 0 if absent.

        Raises TypeError for any other symbol type.
        """
        if isinstance(symbol, Symbol):
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return False; overridden by Constant."""
        return False

    def values(self):
        """Yield every coefficient (sorted symbol order), then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    @property
    def symbol(self):
        """Raise ValueError; only Symbol instances have a single symbol."""
        raise ValueError('not a symbol: {}'.format(self))

    def issymbol(self):
        """Return False; overridden by Symbol."""
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return the sum of two linear expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return the difference of two linear expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """Multiply by a constant expression.

        If self is constant and other is not, return NotImplemented so that
        the reflected call other.__rmul__(self) scales other instead. The
        product of two non-constant expressions is not linear and raises
        ValueError.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """Divide by a constant expression, producing exact Fractions.

        Dividing by a non-constant expression raises ValueError.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    def __rtruediv__(self, other):
        """Return ``other / self`` for a rational number ``other``.

        Only defined when self is constant; dividing by a non-constant
        expression raises ValueError. Non-rational operands give
        NotImplemented.
        """
        # Fix: the check was isinstance(other, self), which passed an instance
        # as the type argument and raised TypeError on every call.
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            # other is a plain number here, which has no _parenstr().
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return NotImplemented

    def __str__(self):
        """Format as e.g. ``2*x - y + 3``; the zero expression prints ``0``."""
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or a symbol
        (or always, when ``always`` is true)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        """Structural equality: same coefficients and same constant.

        This is "normal" equality, not a constraint; use _eq / eq to build an
        equality polyhedron.
        """
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """Return self scaled by the lcm of all denominators, so that every
        coefficient and the constant are integers."""
        # fractions.gcd was removed in Python 3.9; math.gcd is the supported
        # replacement (Python 3.5+).
        import math
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron where self == other (integer-scaled)."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the Polyhedron where self <= other, i.e. other - self >= 0."""
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the Polyhedron where self < other.

        For integer points this is other - self - 1 >= 0.
        """
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the Polyhedron where self >= other, i.e. self - other >= 0."""
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the Polyhedron where self > other.

        For integer points this is self - other - 1 >= 0.
        """
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
300
301
class Constant(Expression):
    """A linear expression without symbols, i.e. a rational constant."""

    def __new__(cls, numerator=0, denominator=None):
        """Create a constant.

        With a denominator, the value is Fraction(numerator, denominator).
        Otherwise numerator must be a rational number or a Constant instance,
        and any other type raises TypeError.
        """
        self = object.__new__(cls)
        if denominator is not None:
            self._constant = Fraction(numerator, denominator)
        elif isinstance(numerator, numbers.Rational):
            self._constant = numerator
        elif isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            raise TypeError('constant must be a rational number or a Constant instance')
        # A constant has no symbols at all.
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        """Always True for a Constant."""
        return True

    def __bool__(self):
        """Truthy exactly when the value is non-zero."""
        return bool(self.constant)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._constant)
328
329
class Symbol(Expression):
    """A linear expression made of a single symbol with coefficient 1."""

    def __new__(cls, name):
        """Create the symbol called ``name`` (a str or another Symbol).

        Raises TypeError for any other argument type.
        """
        if isinstance(name, Symbol):
            name = name.symbol
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        # Expression.__new__ would re-dispatch on its arguments, so allocate
        # the instance directly.
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # Fix: a one-element tuple. tuple(name) split a multi-character name
        # such as 'ab' into ('a', 'b'), which corrupted coefficients()/values().
        self._symbols = (name,)
        self._symbol = name
        self._dimension = 1
        return self

    @property
    def symbol(self):
        """The symbol's name."""
        return self._symbol

    def issymbol(self):
        """Always True for a Symbol."""
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._symbol)
354
def symbols(names):
    """Return a generator of Symbol instances.

    ``names`` is either an iterable of names, or a single string in which
    names are separated by commas and/or whitespace.
    """
    if isinstance(names, str):
        # Treat every comma as whitespace, then split on whitespace runs.
        names = ' '.join(names.split(',')).split()
    return (Symbol(each) for each in names)
359
360
@_polymorphic_operator
def eq(a, b):
    """Return the Polyhedron of integer points where ``a == b``.

    Rational arguments are promoted to Constant. Calls Expression._eq rather
    than __eq__: __eq__ is structural equality of expressions and returns a
    bool, whereas le/lt/ge/gt all build polyhedra.
    """
    return a._eq(b)
364
@_polymorphic_operator
def le(a, b):
    """Build ``a <= b`` via Expression.__le__ (rational ``a`` becomes a Constant)."""
    comparison = a.__le__
    return comparison(b)
368
@_polymorphic_operator
def lt(a, b):
    """Build ``a < b`` via Expression.__lt__ (rational ``a`` becomes a Constant)."""
    comparison = a.__lt__
    return comparison(b)
372
@_polymorphic_operator
def ge(a, b):
    """Build ``a >= b`` via Expression.__ge__ (rational ``a`` becomes a Constant)."""
    comparison = a.__ge__
    return comparison(b)
376
@_polymorphic_operator
def gt(a, b):
    """Build ``a > b`` via Expression.__gt__ (rational ``a`` becomes a Constant)."""
    comparison = a.__gt__
    return comparison(b)
380
381
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by a tuple of equalities (each meaning
    ``expr == 0``) and a tuple of inequalities (each meaning ``expr >= 0``),
    all with integer coefficients and constants. Set queries are answered by
    converting to isl with ``_toisl``.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """Create a polyhedron.

        equalities: iterable of expressions, each meaning ``expr == 0``.
        inequalities: iterable of expressions, each meaning ``expr >= 0``.
        A str first argument is forwarded to fromstring() (not implemented).

        Raises TypeError when a constraint has a non-integer value, or when
        a str is combined with an ``inequalities`` argument.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = []
        if equalities is not None:
            for constraint in equalities:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                                '{} == 0'.format(constraint))
                self._equalities.append(constraint)
        self._equalities = tuple(self._equalities)
        self._inequalities = []
        if inequalities is not None:
            for constraint in inequalities:
                for value in constraint.values():
                    if value.denominator != 1:
                        # Inequalities mean ``expr >= 0`` (see __str__);
                        # the message used to say ``<= 0``.
                        raise TypeError('non-integer constraint: '
                                '{} >= 0'.format(constraint))
                self._inequalities.append(constraint)
        self._inequalities = tuple(self._inequalities)
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @classmethod
    def fromstring(cls, string):
        """Parse a polyhedron from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def equalities(self):
        """Tuple of expressions constrained to ``== 0``."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions constrained to ``>= 0``."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of all symbol names used by the constraints."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols."""
        return len(self.symbols)

    def __bool__(self):
        # Fix: the method is named isempty; is_empty raised AttributeError.
        return not self.isempty()

    def __contains__(self, value):
        """Test whether a point belongs to the polyhedron (not implemented)."""
        raise NotImplementedError

    def __eq__(self, other):
        """Compare two polyhedra through isl.

        Returns NotImplemented for non-Polyhedron operands.
        """
        # NOTE(review): each side is converted in its own symbol space (no
        # _symbolunion), so polyhedra over differently named symbols may
        # compare equal. The original comment says that is wanted -- confirm.
        # isl's "plain" check is presumably syntactic and may miss equal sets
        # written differently -- verify against the isl manual.
        if not isinstance(other, Polyhedron):
            return NotImplemented
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        """Return True if isl reports the polyhedron as empty."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        """Return True if isl reports the polyhedron as the universe."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """Return True if the polyhedron has no element in common with other."""
        # Both sides must live in the same space, as in the other binary
        # predicates; converting each in its own space compared unrelated
        # dimensions.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return True if every element of the polyhedron is in other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # Fix: previously tested strict_subset(other, self), i.e. the reverse
        # relation, and strict instead of non-strict.
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if the polyhedron is a strict subset of other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # Fix: arguments were reversed (tested other < self).
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        """Return True if every element of other is in the polyhedron."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_subset(other, bset))

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return True if other is a strict subset of the polyhedron."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # The result used to be computed, discarded, and NotImplementedError
        # raised instead.
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        """Return the convex union of the polyhedron and others (not implemented)."""
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """Return a new polyhedron with the elements common to self and all others.

        The intersection of polyhedra is the conjunction of their constraints,
        so the constraint tuples are simply concatenated. With no others, this
        returns a copy of self.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        """Return the elements of the polyhedron that are not in other.

        Returns the raw result of libisl.isl_set_subtract, not a Polyhedron,
        because _fromisl is not implemented yet.
        """
        # NOTE(review): isl_set_subtract presumably takes ownership of both
        # arguments while the isl.BasicSet wrappers may still free them --
        # confirm there is no double free.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Return the sorted list of symbols used by self or any of others."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """Convert the polyhedron to an ``isl.BasicSet``.

        symbols: ordered sequence of symbol names; set dimension i is
            symbols[i]. Defaults to self.symbols. Every symbol used by the
            constraints must appear, otherwise ``.index`` raises ValueError.
        """
        if symbols is None:
            symbols = self.symbols
        num_coefficients = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, num_coefficients)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        # NOTE(review): ``ls`` is never released (isl_local_space_free), which
        # looks like a leak -- confirm against pypol.isl's ctypes setup.
        for constraint in self.equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            ceq = self._setislvalues(ceq, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for constraint in self.inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            cin = self._setislvalues(cin, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        bset = isl.BasicSet(bset)
        return bset

    @staticmethod
    def _setislvalues(islconstraint, constraint, symbols):
        """Copy the constant and coefficients of ``constraint`` into the isl
        constraint ``islconstraint`` and return the updated isl constraint."""
        if constraint.constant:
            value = str(constraint.constant).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, value)
            islconstraint = libisl.isl_constraint_set_constant_val(
                    islconstraint, val)
        for symbol, coefficient in constraint.coefficients():
            number = str(coefficient).encode()
            num = libisl.isl_val_read_from_str(_main_ctx, number)
            iden = symbols.index(symbol)
            islconstraint = libisl.isl_constraint_set_coefficient_val(
                    islconstraint, libisl.isl_dim_set, iden, num)
        return islconstraint

    @classmethod
    def _fromisl(cls, bset):
        """Convert an isl basic set back into a Polyhedron (not implemented).

        The isl example code gives the isl form as:
            "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}"
        while our printer gives the form as:
            { [i0, i1] : 2i1 >= -2 - i0 }
        """
        raise NotImplementedError
604
# Placeholders for the empty and universe polyhedra. The original comments
# suggest they are meant to become eq(0, 1) and Polyhedron() -- TODO confirm.
empty = None
universe = None
607
if __name__ == '__main__':
    # Manual smoke test for Polyhedron._toisl.
    # The author reports that ex1 does not work (even without adding values).
    ex1 = Expression(constant=3, coefficients={'a': 6, 'b': 6})
    ex2 = Expression(constant=3, coefficients={'x': 4, 'y': 2})
    p, p2 = Polyhedron(equalities=[ex2]), Polyhedron(equalities=[ex2])
    print(p._toisl())  # check that the values reach _toisl correctly