Improve representation of Constants
[linpy.git] / pypol / linear.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from fractions import Fraction, gcd
7
8 from pypol import isl
9 from pypol.isl import libisl
10
11
# Public API of the module.
__all__ = [
    # linear expressions
    'Expression', 'Constant', 'Symbol', 'symbols',
    # constraint builders
    'eq', 'le', 'lt', 'ge', 'gt',
    # polyhedra
    'Polyhedron', 'Empty', 'Universe',
]
18
19
def _polymorphic_method(func):
    """Let a binary Expression method accept rational numbers as well.

    A rational right operand is wrapped in a Constant before ``func`` is
    called; any other non-Expression operand yields NotImplemented so that
    Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """Decorate a module-level comparison function ``func(a, b)``.

    A rational left operand is wrapped in a Constant. The right operand is
    not converted here: ``func`` is expected to delegate to a polymorphic
    method, which performs that conversion itself.

    Raises TypeError when ``a`` is neither a linear expression nor a rational
    number, or when the delegated method rejects ``b``.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        result = func(a, b)
        # A polymorphic method signals an unsupported operand with
        # NotImplemented; a plain function must not leak that sentinel.
        if result is NotImplemented:
            raise TypeError('arguments must be linear expressions')
        return result
    return wrapper
43
44
# Module-wide isl context. Polyhedron._toisl passes it to the isl space and
# value constructors that it calls.
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to rational coefficients and
    carries a rational constant term. The constructor normalizes degenerate
    cases:
      - An expression without non-zero coefficients is returned as a Constant.
      - A single symbol with coefficient 1 and no constant term is returned
        as a Symbol.

    Comparison operators (<=, <, >=, >) build Polyhedron constraints, while
    == is plain structural equality (use _eq() for an equality constraint).
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from its coefficients and constant term.

        coefficients may be:
          - a string, which is parsed with fromstring() (constant must then
            be omitted);
          - a dict, or an iterable of (symbol, coefficient) pairs;
          - None.

        Symbols are strings or Symbol instances. Coefficients and the
        constant are rational numbers or Constant instances. Zero
        coefficients are dropped.

        Raises TypeError on malformed arguments.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @staticmethod
    def _numberliteral(node):
        """Return the value of a numeric literal AST node, or None."""
        # Python 3.8+ parses literals as ast.Constant. Older versions produce
        # ast.Num, which is deprecated since 3.8 and removed in Python 3.14.
        if hasattr(ast, 'Constant') and isinstance(node, ast.Constant):
            value = node.value
        elif hasattr(ast, 'Num') and isinstance(node, ast.Num):
            value = node.n
        else:
            return None
        # bool is an int subclass, but True/False are not numeric literals.
        if isinstance(value, numbers.Number) and not isinstance(value, bool):
            return value
        return None

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into an expression.

        Supports names, numeric literals, unary - and +, and the binary
        operators +, -, * and /. Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root of a tree produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        else:
            value = cls._numberliteral(node)
            if value is not None:
                # Non-rational literals (e.g. floats) are rejected by Constant.
                return Constant(value)
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression such as '2x + 3y - 1'.

        A number or a closing parenthesis directly followed by a name or an
        opening parenthesis is read as a product ('2x' means '2*x').
        """
        string = string.strip()
        # The lookbehind keeps digits that belong to a name (e.g. 'x1y')
        # from being treated as a coefficient.
        string = re.sub(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        # Parse a single expression. The second positional parameter of
        # ast.parse is the filename, so the mode is passed by keyword.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    @property
    def symbols(self):
        """Sorted tuple of the names of the symbols with a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """Return the coefficient of symbol (a name or a Symbol), 0 if absent."""
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        return False

    def values(self):
        """Yield every coefficient in symbol order, then the constant term."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # other has been converted to an Expression by the decorator; an
        # unsupported operand now yields NotImplemented instead of raising.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when the product would be non-linear. Returns
        NotImplemented when self is the constant, so that the reflected call
        on the other operand performs the scaling.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant, producing Fraction coefficients.

        Raises ValueError for a non-constant divisor, and ZeroDivisionError
        for zero.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic_method
    def __rtruediv__(self, other):
        """
        Compute other / self, where other has already been converted to an
        Expression by the decorator.

        Only a constant divisor keeps the result linear.
        """
        if self.isconstant():
            return other / self.constant
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other._parenstr(), self._parenstr()))

    def __str__(self):
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or a symbol
        (or always if always is true)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Return the expression scaled by the (positive) least common multiple
        of the denominators of its coefficients and constant term, so that
        all of them become integers.
        """
        # fractions.gcd was deprecated and then removed in Python 3.9.
        from math import gcd as _gcd
        lcm = 1
        for value in self.values():
            lcm = lcm * value.denominator // _gcd(lcm, value.denominator)
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron where self == other."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        # For integer points, a strict inequality e > 0 is e - 1 >= 0.
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
331
332
class Constant(Expression):
    """
    A linear expression without symbols, i.e. a rational constant.

    Constant(c) wraps a rational number, or copies another Constant.
    Constant(numerator, denominator) stores Fraction(numerator, denominator);
    either argument may itself be a Constant.
    """

    # Keep instances dict-free, like Expression and Symbol.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        self = object.__new__(cls)
        if denominator is None:
            if isinstance(numerator, numbers.Rational):
                self._constant = numerator
            elif isinstance(numerator, Constant):
                self._constant = numerator.constant
            else:
                raise TypeError('constant must be a rational number or a Constant instance')
        else:
            # Unwrap Constant operands so that Fraction() gets plain numbers.
            if isinstance(numerator, Constant):
                numerator = numerator.constant
            if isinstance(denominator, Constant):
                denominator = denominator.constant
            self._constant = Fraction(numerator, denominator)
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return bool(self.constant)

    def __hash__(self):
        # Constant(c) == c is True (Expression.__eq__ wraps numbers in
        # Constant), so equal objects must also hash alike.
        return hash(self.constant)

    def __repr__(self):
        constant = self.constant
        if constant.denominator == 1:
            # Integral values are shown as plain integers, e.g. Constant(4)
            # rather than Constant(Fraction(4, 1)).
            return '{}({!r})'.format(self.__class__.__name__,
                    constant.numerator)
        return '{}({!r}, {!r})'.format(self.__class__.__name__,
                constant.numerator, constant.denominator)
363
class Symbol(Expression):
    """
    A linear expression made of a single named symbol with coefficient 1 and
    no constant term.
    """

    # Only the new attribute: redeclaring slots already defined in Expression
    # is undefined behaviour according to the Python data model.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # A one-element tuple: tuple(name) would split the name into
        # individual characters.
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)
392
def symbols(names):
    """Return a generator of Symbol objects.

    names is either an iterable of names (or Symbol instances), or a string
    whose names are separated by commas and/or whitespace, e.g. 'x, y z'.
    """
    if not isinstance(names, str):
        return (Symbol(name) for name in names)
    # Commas act as separators, just like whitespace.
    parts = ' '.join(names.split(',')).split()
    return (Symbol(part) for part in parts)
397
398
@_polymorphic_operator
def eq(a, b):
    """Return the Polyhedron of points where a == b.

    Delegates to Expression._eq. Expression.__eq__ is structural equality
    and returns a bool, not a constraint.
    """
    return a._eq(b)
402
@_polymorphic_operator
def le(a, b):
    """Return a.__le__(b): for a supported b, the Polyhedron where a <= b."""
    compare = a.__le__
    return compare(b)

@_polymorphic_operator
def lt(a, b):
    """Return a.__lt__(b): for a supported b, the Polyhedron where a < b."""
    compare = a.__lt__
    return compare(b)

@_polymorphic_operator
def ge(a, b):
    """Return a.__ge__(b): for a supported b, the Polyhedron where a >= b."""
    compare = a.__ge__
    return compare(b)

@_polymorphic_operator
def gt(a, b):
    """Return a.__gt__(b): for a supported b, the Polyhedron where a > b."""
    compare = a.__gt__
    return compare(b)
418
419
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is the set of points satisfying a conjunction of linear
    constraints with integer coefficients:
      - equalities: expression == 0;
      - inequalities: expression >= 0.
    Set-level queries are delegated to isl through _toisl().
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron from iterables of equalities and inequalities, or
        parse it from a string passed as the only argument.

        Raises TypeError when a constraint has a non-integer coefficient or
        constant term.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checkconstraints(equalities, '{} == 0')
        self._inequalities = cls._checkconstraints(inequalities, '{} >= 0')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _checkconstraints(constraints, template):
        """Return constraints as a tuple, rejecting non-integer ones.

        template renders an offending constraint in the TypeError message.
        """
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                            + template.format(constraint))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def _fromast(cls, node):
        """
        Convert a parsed conjunction (&) of comparisons into a pair of lists
        (equalities, inequalities).

        Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root of a tree produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd):
            equalities1, inequalities1 = cls._fromast(node.left)
            equalities2, inequalities2 = cls._fromast(node.right)
            return equalities1 + equalities2, inequalities1 + inequalities2
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                # Scale to integer coefficients *before* turning a strict
                # inequality into a non-strict one, as Expression.__lt__ does.
                if isinstance(op, ast.Lt):
                    inequalities.append((right - left)._toint() - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append((right - left)._toint())
                elif isinstance(op, ast.Eq):
                    equalities.append((left - right)._toint())
                elif isinstance(op, ast.GtE):
                    inequalities.append((left - right)._toint())
                elif isinstance(op, ast.Gt):
                    inequalities.append((left - right)._toint() - 1)
                else:
                    break
                left = right
            else:
                return equalities, inequalities
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse constraints such as '{2x + y >= 1, x = 0}'.

        Syntax:
          - Constraints are separated by ',', ';', 'and', '&&', '/\\' or '∧'.
          - A single '=' means equality.
          - '2x' means '2*x'.
          - An empty constraint list gives the universe.
        """
        string = string.strip()
        string = re.sub(r'^\{\s*|\s*\}$', '', string)
        # Split before inserting implicit products, otherwise
        # 'x <= 1 and y >= 0' would become 'x <= 1*and y >= 0'. 'and' must not
        # be preceded by a letter or underscore, nor followed by a word
        # character, so that names like 'band' are left intact.
        tokens = re.split(r',|;|(?<![^\W\d])and\b|&&|/\\|∧', string,
                flags=re.I)
        constraints = []
        for token in tokens:
            token = token.strip()
            if not token:
                continue
            token = re.sub(r'(?<![<=>!])=(?![<=>])', '==', token)
            # The lookbehind keeps digits that belong to a name (e.g. 'x1y')
            # from being treated as a coefficient.
            token = re.sub(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2',
                    token)
            constraints.append('({})'.format(token))
        if not constraints:
            return cls()
        # The second positional parameter of ast.parse is the filename, so
        # the mode is passed by keyword.
        tree = ast.parse(' & '.join(constraints), mode='eval')
        equalities, inequalities = cls._fromast(tree)
        return cls(equalities, inequalities)

    @property
    def equalities(self):
        return self._equalities

    @property
    def inequalities(self):
        return self._inequalities

    @property
    def constraints(self):
        return self._constraints

    @property
    def symbols(self):
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        # works correctly when symbols is not passed
        # should be equal if values are the same even if symbols are different
        # NOTE(review): isl's "plain" equality is a cheap check that may
        # report False for equal sets with different representations.
        if not isinstance(other, Polyhedron):
            return NotImplemented
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def _islpair(self, other):
        """Return (self, other) as isl basic sets over their common symbols."""
        symbols = self._symbolunion(other)
        return self._toisl(symbols), other._toisl(symbols)

    def isdisjoint(self, other):
        """Return True if self and other have no point in common."""
        bset, other = self._islpair(other)
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return True if every point of self belongs to other."""
        bset, other = self._islpair(other)
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if self is a strict subset of other."""
        bset, other = self._islpair(other)
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        """Return True if every point of other belongs to self."""
        bset, other = self._islpair(other)
        return bool(libisl.isl_set_is_subset(other, bset))

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return True if self is a strict superset of other."""
        bset, other = self._islpair(other)
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """Return the polyhedron of points common to self and all others.

        The intersection of polyhedra is described by the conjunction of all
        their constraints.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        # return a new polyhedron with elements in the polyhedron that are not in the other
        # NOTE(review): this returns the raw isl set, not a Polyhedron, because
        # _fromisl is not implemented yet.
        bset, other = self._islpair(other)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                    ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Return the sorted list of symbols used by self or any of others."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Build an isl basic set from the constraints.

        Dimension i of the set corresponds to symbols[i]; symbols defaults to
        self.symbols.
        """
        if symbols is None:
            symbols = self.symbols
        dimension = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        for equality in self.equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            ceq = self._islconstraint(ceq, equality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for inequality in self.inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            cin = self._islconstraint(cin, inequality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        bset = isl.BasicSet(bset)
        return bset

    @staticmethod
    def _islconstraint(constraint, expression, symbols):
        """Copy the coefficients and constant of expression into an isl
        constraint, and return the updated constraint."""
        for symbol, coefficient in expression.coefficients():
            val = str(coefficient).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            dim = symbols.index(symbol)
            constraint = libisl.isl_constraint_set_coefficient_val(
                    constraint, libisl.isl_dim_set, dim, val)
        if expression.constant != 0:
            val = str(expression.constant).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            constraint = libisl.isl_constraint_set_constant_val(constraint, val)
        return constraint

    @classmethod
    def _fromisl(cls, bset, symbols):
        """
        Convert an isl basic set back into a Polyhedron (not implemented yet).

        Textual forms of isl sets, for reference:
          - isl example code: "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}"
          - our printer:      "{ [i0, i1] : 2i1 >= -2 - i0 }"
        """
        raise NotImplementedError
695
# The constraint 1 == 0 has no solution, hence the empty polyhedron.
Empty = Polyhedron(equalities=[Constant(1)])

# No constraint at all: every point belongs to it.
Universe = Polyhedron()
699
700
if __name__ == '__main__':
    # 2a + 2b is even and never equals -1, so the first set has no integer
    # point; x = -1, y = 0 satisfies the second one.
    for text in ('2a + 2b + 1 == 0', '3x + 2y + 3 == 0'):
        print(Polyhedron(text)._toisl())