3cd76335e443cba2cda1c711bb8b1c5f6ac50cbd
[linpy.git] / pypol / linear.py
import functools
import numbers

from fractions import Fraction

try:
    from math import gcd
except ImportError:  # Python < 3.5: gcd was only available from fractions
    from fractions import gcd

from pypol import isl
from pypol.isl import libisl
8
9
# Public names exported by ``from pypol.linear import *``.
__all__ = [
    # linear expressions
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    # constraint builders
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    # polyhedra
    'Polyhedron',
    'Empty',
    'Universe',
]
16
17
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so its second operand may be a number.

    Expression operands are passed through unchanged and rational numbers are
    wrapped in a Constant first. Any other operand makes the wrapper return
    NotImplemented, so Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
28
def _polymorphic_operator(func):
    """
    Decorate a module-level binary operator so its left operand may be a number.

    A rational left operand is wrapped in a Constant. An Expression is passed
    through unchanged. Anything else raises TypeError.
    """
    # Only the left operand is converted here: the decorated function is
    # expected to call a _polymorphic_method, which converts the right one.
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(Constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
41
42
# Module-level isl context. Polyhedron._toisl passes it to libisl when
# allocating spaces and parsing constraint values.
_main_ctx = isl.Context()
44
45
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names to non-zero rational coefficients and has
    a rational constant term, e.g. ``2*x - y + 3``. The constructor normalizes
    its result:
    - it returns a Constant when no non-zero coefficient remains;
    - it returns a Symbol for a single coefficient of 1 with a zero constant.

    ``==`` is structural equality. ``<=``, ``<``, ``>=`` and ``>`` build
    Polyhedron constraints.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a dict or an iterable of (symbol, coefficient) pairs.
            Symbols are strings or Symbol instances. Coefficients are rational
            numbers or Constant instances. A string is passed to fromstring(),
            which is not implemented yet.
        constant -- the constant term, a rational number or a Constant.

        Raises TypeError on arguments of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        # Zero coefficients are dropped, so they never show up in
        # _coefficients (and therefore not in symbols, str() or ==).
        coefficients = [(symbol, coefficient)
                for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = str(symbol)
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def fromstring(cls, string):
        """Parse an expression from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def symbols(self):
        """Sorted tuple of the symbol names with a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of *symbol* (a str or Symbol), or 0 if absent.

        Raises TypeError if *symbol* is neither a string nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return False; Constant overrides this."""
        return False

    def values(self):
        """Yield every coefficient in sorted symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    @property
    def symbol(self):
        """Raise ValueError: only Symbol instances have a single symbol."""
        raise ValueError('not a symbol: {}'.format(self))

    def issymbol(self):
        """Return False; Symbol overrides this."""
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant expression or a rational number.

        Raises ValueError if both operands have symbols. When self is constant
        and other is not, return NotImplemented so the reflected
        multiplication (other * self) is tried instead.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant expression or rational number (exact Fractions).

        Raises ValueError when the divisor has symbols. Dividing by zero
        raises ZeroDivisionError (from Fraction).
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    def __rtruediv__(self, other):
        """
        Compute ``other / self`` for a rational number *other*.

        This is only linear when self is constant. Otherwise ValueError is
        raised. Non-rational operands return NotImplemented.
        """
        # The original tested isinstance(other, self), which always raised
        # TypeError because self is an instance, not a type.
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            # other is a plain number here, so format it directly (it has no
            # _parenstr method).
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return NotImplemented

    def __str__(self):
        """Render as e.g. ``2*x - y + 3``; the zero expression renders as ``0``."""
        string = ''
        for i, symbol in enumerate(self.symbols):
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += '-{}'.format(symbol) if i == 0 else ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                # Coefficients are never zero (dropped in __new__), so this is
                # the negative case.
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0 and not self.symbols:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """str(self), parenthesized unless it is a constant or a symbol (or *always*)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" (structural) equality, returning a bool; use _eq to build
        # an equality constraint.
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Scale by the LCM of all denominators so every value becomes an integer.

        Polyhedron rejects constraints with non-integer values, so the
        constraint builders below go through this method.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron constraint ``self == other``."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the Polyhedron constraint ``other - self >= 0``."""
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        """
        Return the Polyhedron constraint ``other - self - 1 >= 0``.

        The strict inequality is expressed as ``>= 1`` on the integer-scaled
        difference.
        """
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the Polyhedron constraint ``self - other >= 0``."""
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the Polyhedron constraint ``self - other - 1 >= 0``."""
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
299
300
class Constant(Expression):
    """
    A linear expression without symbols: just a rational value.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Build a constant in one of three ways:
        - from a rational number;
        - from another Constant;
        - from a numerator/denominator pair, combined into a Fraction.

        Raises TypeError when a lone argument is neither a rational number
        nor a Constant.
        """
        if denominator is not None:
            value = Fraction(numerator, denominator)
        elif isinstance(numerator, numbers.Rational):
            value = numerator
        elif isinstance(numerator, Constant):
            value = numerator.constant
        else:
            raise TypeError('constant must be a rational number or a Constant instance')
        self = object.__new__(cls)
        self._constant = value
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        """Always True: a Constant has no symbols."""
        return True

    def __bool__(self):
        """Truthiness follows the wrapped value (falsy when it is zero)."""
        return bool(self.constant)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._constant)
327
328
class Symbol(Expression):
    """
    A linear expression consisting of one symbol with coefficient 1.
    """

    def __new__(cls, name):
        """
        Create a symbol from a name string or copy another Symbol's name.

        Raises TypeError for any other argument type.
        """
        if isinstance(name, Symbol):
            name = name.symbol
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # A one-element tuple: tuple(name) would split a multi-character name
        # such as 'xy' into ('x', 'y').
        self._symbols = (name,)
        self._symbol = name
        self._dimension = 1
        return self

    @property
    def symbol(self):
        """The symbol's name."""
        return self._symbol

    def issymbol(self):
        """Always True for Symbol instances."""
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._symbol)
353
def symbols(names):
    """
    Return a generator of Symbol objects, one per name.

    *names* is either a string of names separated by commas and/or whitespace,
    or an iterable of names (strings or Symbol instances). Symbols are created
    lazily, as the generator is consumed.
    """
    name_list = (names.replace(',', ' ').split()
            if isinstance(names, str) else names)
    return (Symbol(each) for each in name_list)
358
359
@_polymorphic_operator
def eq(a, b):
    """
    Return the equality constraint ``a == b`` as a Polyhedron.

    This delegates to Expression._eq. Expression.__eq__ is structural equality
    and returns a bool, so calling it here made e.g. ``eq(0, 1)`` evaluate to
    ``False`` instead of a polyhedron.
    """
    return a._eq(b)
363
@_polymorphic_operator
def le(a, b):
    """
    Build the constraint ``a <= b`` by delegating to Expression.__le__.

    The result is a Polyhedron, or NotImplemented when *b* is neither an
    Expression nor a rational number.
    """
    compare = a.__le__
    return compare(b)
367
@_polymorphic_operator
def lt(a, b):
    """
    Build the strict constraint ``a < b`` by delegating to Expression.__lt__.

    The result is a Polyhedron, or NotImplemented when *b* is neither an
    Expression nor a rational number.
    """
    compare = a.__lt__
    return compare(b)
371
@_polymorphic_operator
def ge(a, b):
    """
    Build the constraint ``a >= b`` by delegating to Expression.__ge__.

    The result is a Polyhedron, or NotImplemented when *b* is neither an
    Expression nor a rational number.
    """
    compare = a.__ge__
    return compare(b)
375
@_polymorphic_operator
def gt(a, b):
    """
    Build the strict constraint ``a > b`` by delegating to Expression.__gt__.

    The result is a Polyhedron, or NotImplemented when *b* is neither an
    Expression nor a rational number.
    """
    compare = a.__gt__
    return compare(b)
379
380
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is a conjunction of integer linear constraints. Every
    expression in ``equalities`` is constrained to ``== 0`` and every
    expression in ``inequalities`` to ``>= 0``. Set-level queries convert the
    polyhedron to an isl basic set with ``_toisl``.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of Expression constraints.

        A single string argument is passed to fromstring(), which is not
        implemented yet. Raises TypeError in two cases:
        - a constraint has a non-integer coefficient or constant;
        - a string argument is combined with inequalities.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integerconstraints(equalities, '{} == 0')
        # Inequalities mean "expression >= 0" (see __str__ and the isl
        # inequality constraints built in _toisl), so report them that way.
        self._inequalities = cls._integerconstraints(inequalities, '{} >= 0')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _integerconstraints(constraints, template):
        # Return *constraints* (None means none) as a tuple, raising TypeError
        # on the first constraint that has a non-integer value; *template*
        # renders the offending constraint in the error message.
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            if any(value.denominator != 1 for value in constraint.values()):
                raise TypeError('non-integer constraint: '
                        + template.format(constraint))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def fromstring(cls, string):
        """Parse a polyhedron from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def equalities(self):
        """Tuple of expressions constrained to ``== 0``."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions constrained to ``>= 0``."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of all symbol names used by the constraints."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols."""
        return len(self.symbols)

    def __bool__(self):
        """True unless the polyhedron is empty."""
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        """
        Compare with isl_basic_set_plain_is_equal.

        Each side is converted with its own symbols, not a shared symbol list.
        Non-Polyhedron operands return NotImplemented.
        """
        # Original author's notes: this works correctly when symbols are not
        # passed; polyhedra should be equal if the values are the same even
        # if the symbols are different.
        if not isinstance(other, Polyhedron):
            return NotImplemented
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        """Return True if isl reports that no point satisfies the constraints."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        """Return True if isl reports that every point satisfies the constraints."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """Return True if the polyhedron has no elements in common with *other*."""
        # Both sides must be expressed over the same symbols, in the same
        # order, for their dimensions to line up.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return True if every element of the polyhedron is also in *other*."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # Non-strict containment, self ⊆ other (argument order matters).
        return bool(libisl.isl_basic_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Proper subset: self ⊆ other and other ⊄ self."""
        return self.issubset(other) and not other.issubset(self)

    def issuperset(self, other):
        """Return True if every element of *other* is in the polyhedron."""
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Proper superset: other ⊆ self and self ⊄ other."""
        return other.issubset(self) and not self.issubset(other)

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return the polyhedron of elements common to self and all *others*.

        A polyhedron is a conjunction of constraints, so the intersection
        simply collects the equalities and inequalities of every operand.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        """
        Compute the elements of the polyhedron that are not in *other*.

        NOTE(review): this returns the raw libisl result, not a Polyhedron.
        Converting it back requires _fromisl, which is not implemented.
        """
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                    ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        # Sorted list of the symbol names used by self and all *others*.
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Convert to an isl basic set wrapped in isl.BasicSet.

        symbols -- the ordered symbol names that give the set dimensions
            (defaults to self.symbols). Each constraint symbol must appear in
            it; otherwise list/tuple .index raises ValueError.
        """
        if symbols is None:
            symbols = self.symbols
        space = libisl.isl_space_set_alloc(_main_ctx, 0, len(symbols))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        for constraint in self.equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            bset = self._addislconstraint(bset, ceq, constraint, symbols)
        for constraint in self.inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            bset = self._addislconstraint(bset, cin, constraint, symbols)
        # Only copies of ls were handed to isl above; release the original.
        libisl.isl_local_space_free(ls)
        return isl.BasicSet(bset)

    @staticmethod
    def _addislconstraint(bset, islconstraint, constraint, symbols):
        # Copy the constant and coefficients of the Expression *constraint*
        # into *islconstraint*, then add it to *bset* and return the result.
        if constraint.constant:
            value = str(constraint.constant).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, value)
            islconstraint = libisl.isl_constraint_set_constant_val(
                    islconstraint, val)
        for symbol, coefficient in constraint.coefficients():
            number = str(coefficient).encode()
            num = libisl.isl_val_read_from_str(_main_ctx, number)
            index = symbols.index(symbol)
            # libisl.isl_dim_set selects the set dimensions (value 3 in isl).
            islconstraint = libisl.isl_constraint_set_coefficient_val(
                    islconstraint, libisl.isl_dim_set, index, num)
        return libisl.isl_basic_set_add_constraint(bset, islconstraint)

    @classmethod
    def _fromisl(cls, bset):
        """
        Convert an isl basic set back into a Polyhedron (not implemented yet).

        isl example code gives the isl form as
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}", while our
        printer gives e.g. "{ [i0, i1] : 2i1 >= -2 - i0 }".
        """
        raise NotImplementedError
608
# The equality -1 == 0 has no solution, so this polyhedron is empty. It is
# built directly: eq(0, 1) went through Expression.__eq__ and produced the
# bool False instead of a Polyhedron.
Empty = Polyhedron(equalities=[Constant(-1)])
# No constraints at all: every point belongs to it.
Universe = Polyhedron()
611
if __name__ == '__main__':
    # Manual smoke test for Polyhedron._toisl.
    # The author noted that this first expression "does not work" (even
    # without adding values); it is still constructed here.
    problem_expr = Expression(coefficients={'a': 6, 'b': 6}, constant=3)
    sample_expr = Expression(coefficients={'x': 4, 'y': 2}, constant=3)
    poly = Polyhedron(equalities=[sample_expr])
    poly_copy = Polyhedron(equalities=[sample_expr])
    # Print the isl form to check that the values survive the conversion.
    print(poly._toisl())