3939c71d038030028cfbe177ca0974b64693175e
[linpy.git] / pypol / linear.py
1 import functools
2 import numbers
3
4 from fractions import Fraction, gcd
5
6 from pypol import isl
7 from pypol.isl import libisl
8
9
10 __all__ = [
11 'Expression', 'Constant', 'Symbol', 'symbols',
12 'eq', 'le', 'lt', 'ge', 'gt',
13 'Polyhedron',
14 'Empty', 'Universe'
15 ]
16
17
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so its right operand may be a number.

    An Expression right operand is passed through unchanged, a rational
    number is wrapped in a Constant, and any other value makes the decorated
    method return NotImplemented so Python can try the reflected operation.
    """
    @functools.wraps(func)
    def coerced(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return coerced
28
def _polymorphic_operator(func):
    """
    Decorate a module-level binary operator so it accepts a number on the left.

    A rational left operand is promoted to a Constant; any other
    non-Expression left operand raises TypeError.  The right operand is not
    inspected here: the wrapped operator is expected to delegate to a method
    decorated with _polymorphic_method, which converts it.
    """
    @functools.wraps(func)
    def promoted(a, b):
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return promoted
41
42
# Module-wide isl context; Polyhedron._toisl allocates every isl space and
# value it creates in this context.
_main_ctx = isl.Context()
44
45
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names to rational coefficients and carries a
    rational constant term, e.g. ``2*x - y + 3``.  Arithmetic with other
    expressions or rational numbers stays linear: multiplying two
    non-constant expressions, or dividing by a non-constant one, raises
    ValueError.  The ordering operators (<, <=, >, >=) build Polyhedron
    objects, whereas == is plain structural equality (use _eq to build an
    equality constraint).
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression, normalizing degenerate cases.

        coefficients -- None, a dict, or an iterable of (symbol, coefficient)
            pairs; symbols are str or Symbol instances, coefficients are
            rational numbers or Constant instances.  Zero coefficients are
            dropped.  A str argument is handed to fromstring(), which is not
            implemented yet.
        constant -- the constant term, a rational number or a Constant.

        Returns a Constant when no non-zero coefficient remains, a Symbol for
        a single coefficient of 1 with a zero constant, and an Expression
        otherwise.  Raises TypeError on invalid arguments.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # object.__new__ directly: Expression.__new__ itself would recurse
        # into the normalization above.
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def fromstring(cls, string):
        """Parse an expression from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def symbols(self):
        """Sorted tuple of the symbol names with a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a str or Symbol), 0 if absent.

        Raises TypeError for any other argument type.
        """
        if isinstance(symbol, Symbol):
            # Use the name directly rather than str(symbol), which goes
            # through the generic expression printer.
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        return False

    def values(self):
        """Yield every coefficient in symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        """Return other - self, for a rational number on the left."""
        # Decorated so unsupported operands yield NotImplemented (and thus a
        # regular TypeError) instead of failing inside unary minus.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands are non-constant, since the
        product would not be linear.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant and other is not: let Python call
        # other.__rmul__(self), which takes the branch above.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant, producing Fraction values.

        Raises ValueError when other is not constant and ZeroDivisionError
        when it is zero.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                    Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic_method
    def __rtruediv__(self, other):
        """
        Return other / self, for a rational number on the left.

        other arrives wrapped in a Constant, so the regular division applies:
        ValueError when self is not constant, ZeroDivisionError when it is 0.
        """
        return other / self

    def __str__(self):
        """Return a readable form such as '2*x - y + 3' ('0' when empty)."""
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Return self scaled so that every coefficient and the constant are
        integers.

        The multiplier is the (positive) least common multiple of all the
        denominators, so `expr == 0` / `expr >= 0` keep their meaning.
        """
        from math import gcd  # fractions.gcd was removed in Python 3.9
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
            (value.denominator for value in self.values()))
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron {self == other}."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the Polyhedron {self <= other}, i.e. other - self >= 0."""
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the Polyhedron {self < other}."""
        # With integer coefficients the expression is integral at integer
        # points, so "> 0" is the same as "- 1 >= 0".
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the Polyhedron {self >= other}, i.e. self - other >= 0."""
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the Polyhedron {self > other}."""
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
295
296
class Constant(Expression):
    """
    A linear expression without symbols: only a rational constant term.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Build a constant from a rational number, a Constant, or a
        numerator/denominator pair (stored as a Fraction).

        Raises TypeError when a single argument is neither a rational
        number nor a Constant.
        """
        self = object.__new__(cls)
        if denominator is not None:
            self._constant = Fraction(numerator, denominator)
        elif isinstance(numerator, numbers.Rational):
            self._constant = numerator
        elif isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            raise TypeError('constant must be a rational number or a Constant instance')
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        # Truthiness follows the wrapped number (False only for zero).
        return bool(self._constant)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._constant)
323
324
class Symbol(Expression):
    """
    A linear expression made of one named variable with coefficient 1.
    """

    def __new__(cls, name):
        """
        Create a symbol from a str name, or from another Symbol's name.

        Raises TypeError for any other argument type.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # A 1-tuple holding the whole name: tuple(name) would split a
        # multi-character name such as 'xy' into ('x', 'y').
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)
349
def symbols(names):
    """
    Return a lazy iterator of Symbol objects, one per name.

    names -- either an iterable of names, or a single string in which the
        names are separated by commas and/or whitespace.
    """
    if isinstance(names, str):
        names = ' '.join(names.split(',')).split()
    return (Symbol(each) for each in names)
354
355
def _expect_implemented(result):
    # Expression's comparison methods are wrapped by _polymorphic_method and
    # return NotImplemented when the right operand is neither an Expression
    # nor a rational number; raise the same TypeError as
    # _polymorphic_operator instead of handing NotImplemented to the caller.
    if result is NotImplemented:
        raise TypeError('arguments must be linear expressions')
    return result


@_polymorphic_operator
def eq(a, b):
    """
    Return the Polyhedron of points where a == b.

    Expression.__eq__ is structural equality and returns a bool, so the
    constraint is built with Expression._eq instead.
    """
    return _expect_implemented(a._eq(b))


@_polymorphic_operator
def le(a, b):
    """Return the Polyhedron of points where a <= b."""
    return _expect_implemented(a.__le__(b))


@_polymorphic_operator
def lt(a, b):
    """Return the Polyhedron of points where a < b."""
    return _expect_implemented(a.__lt__(b))


@_polymorphic_operator
def ge(a, b):
    """Return the Polyhedron of points where a >= b."""
    return _expect_implemented(a.__ge__(b))


@_polymorphic_operator
def gt(a, b):
    """Return the Polyhedron of points where a > b."""
    return _expect_implemented(a.__gt__(b))
375
376
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by a conjunction of linear constraints with
    integer coefficients: every expression in ``equalities`` must equal 0
    and every expression in ``inequalities`` must be >= 0.  Set predicates
    are evaluated by translating the constraints to isl (see _toisl).
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of constraint expressions.

        A str first argument is handed to fromstring() (not implemented
        yet).  Raises TypeError if a constraint has a non-integer
        coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integerconstraints(equalities, '{} == 0')
        # Inequalities mean "expr >= 0" (see __str__ and Expression.__le__).
        self._inequalities = cls._integerconstraints(inequalities, '{} >= 0')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _integerconstraints(constraints, template):
        # Return `constraints` (None meaning none) as a tuple, after checking
        # that every coefficient and constant is an integer, as needed for
        # the isl constraints built in _toisl.  `template` renders an
        # offending constraint in the error message.
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                        + template.format(constraint))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def fromstring(cls, string):
        """Parse a polyhedron from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def equalities(self):
        """Tuple of expressions constrained to == 0."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions constrained to >= 0."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of every symbol used by a constraint."""
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        # works correctly when symbols is not passed
        # should be equal if values are the same even if symbols are different
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        """Return True if no integer point satisfies the constraints."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        """Return True if every point satisfies the constraints."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """Return True if self has no element in common with other."""
        # NOTE(review): unlike issubset, dimensions are not aligned by symbol
        # name here (the original left _symbolunion commented out) — confirm
        # whether that is intended.
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_set_is_disjoint(bset, other))

    # NOTE(review): the isl_set_* predicates below receive the basic sets
    # produced by _toisl, as elsewhere in this class; confirm that the
    # isl.BasicSet wrapper makes this valid.

    def issubset(self, other):
        """Return True if every element of self is also in other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # isl_set_is_subset(a, b) tests a <= b, so self goes first.
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if self is a strict subset of other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        """Return True if every element of other is in self."""
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return True if self is a strict superset of other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return the polyhedron of elements common to self and all others.

        The result simply conjoins the constraints of every operand.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        """Return the elements of self that are not in other."""
        # NOTE(review): this returns the raw result of isl_set_subtract, not
        # a Polyhedron; converting back needs _fromisl, which is unimplemented.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        # Sorted list of the symbols used by self or any of others, used to
        # give several polyhedra a common isl space.
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Translate self into an isl basic set wrapped in isl.BasicSet.

        symbols -- ordered symbol names giving the set dimensions (defaults
            to self.symbols); pass a shared list to line two polyhedra up
            dimension by dimension.  Must contain every symbol of self.
        """
        if symbols is None:
            symbols = self.symbols
        space = libisl.isl_space_set_alloc(_main_ctx, 0, len(symbols))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        for equality in self.equalities:
            ceq = self._islconstraint(libisl.isl_equality_alloc, ls,
                equality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for inequality in self.inequalities:
            cin = self._islconstraint(libisl.isl_inequality_alloc, ls,
                inequality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        # Every constraint received its own copy of ls; release the
        # reference obtained from isl_local_space_from_space.
        libisl.isl_local_space_free(ls)
        return isl.BasicSet(bset)

    @staticmethod
    def _islconstraint(alloc, ls, expression, symbols):
        # Build one isl constraint with `alloc` (isl_equality_alloc or
        # isl_inequality_alloc), placing each coefficient of `expression` at
        # the dimension of its symbol in `symbols`.
        constraint = alloc(libisl.isl_local_space_copy(ls))
        for symbol, coefficient in expression.coefficients():
            val = str(coefficient).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            dim = symbols.index(symbol)
            constraint = libisl.isl_constraint_set_coefficient_val(
                constraint, libisl.isl_dim_set, dim, val)
        if expression.constant != 0:
            val = str(expression.constant).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            constraint = libisl.isl_constraint_set_constant_val(constraint, val)
        return constraint

    @classmethod
    def _fromisl(cls, bset):
        """
        Convert an isl basic set back into a Polyhedron (not implemented).

        isl example code gives the textual form as
        "{ [i] : exists (a : i = 2a and i >= 10 and i <= 42) }", while the
        printer used here gives e.g. "{ [i0, i1] : 2i1 >= -2 - i0 }".
        """
        raise NotImplementedError
600
# The empty polyhedron: the single equality -1 == 0 (that is, 0 == 1 moved
# to one side) has no solution.  Built directly from a Constant so that it
# is a Polyhedron, not the result of a structural == comparison.
Empty = Polyhedron(equalities=[Constant(-1)])
# The universe: no constraints at all.
Universe = Polyhedron()
603
if __name__ == '__main__':
    # Smoke test: translate two single-equality polyhedra into isl sets.
    # 2*a + 2*b + 1 is always odd, so it is never 0 at an integer point.
    e1 = Expression({'a': 2, 'b': 2}, 1)
    p1 = Polyhedron(equalities=[e1])  # empty
    # 3*x + 2*y + 3 == 0 holds at x = -1, y = 0.
    e2 = Expression({'x': 3, 'y': 2}, 3)
    p2 = Polyhedron(equalities=[e2])  # not empty
    for polyhedron in (p1, p2):
        print(polyhedron._toisl())