Fix Symbol.__slots__
[linpy.git] / pypol / linear.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from fractions import Fraction, gcd
7
8 from pypol import isl
9 from pypol.isl import libisl
10
11
# Public API of this module.
__all__ = [
    # linear expressions
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    # comparison functions building polyhedra
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    # polyhedra
    'Polyhedron',
    'Empty',
    'Universe',
]
18
19
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so that its second operand may also
    be a rational number.

    An Expression operand is passed through unchanged.  A numbers.Rational
    operand is wrapped in a Constant first.  Any other operand makes the
    decorated method return NotImplemented, so Python can try the reflected
    operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """
    Decorate a module-level comparison function so that its first operand
    may also be a rational number (it is wrapped in a Constant).

    Only the left operand is checked here: the decorated function is
    expected to call a _polymorphic_method, which handles the right one.

    Raises TypeError if the first operand is neither a rational number nor
    an Expression.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(a, (numbers.Rational, Expression)):
            raise TypeError('arguments must be linear expressions')
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        return func(a, b)
    return wrapper
43
44
# Module-wide isl context, passed to the libisl allocation and parsing calls
# made by Polyhedron._toisl.
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression is a mapping from symbol names to rational coefficients
    plus a rational constant term.  Construction normalizes the result: an
    expression without symbols is returned as a Constant, and a single
    symbol with coefficient 1 and constant 0 is returned as a Symbol.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a string (parsed with fromstring), a dict, or an
            iterable of (symbol, coefficient) pairs.  Symbols are names or
            Symbol instances; coefficients are rational numbers or Constant
            instances.  Zero coefficients are dropped.
        constant -- the constant term, a rational number or a Constant.

        Raises TypeError on arguments of the wrong type, or when a string is
        given together with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # object.__new__ directly (not through a throwaway object() instance).
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Supported nodes: names, numeric literals, unary +/-, and binary
        +, -, *, /.  Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module):
            if len(node.body) != 1:
                raise SyntaxError('invalid syntax')
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root node produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, getattr(ast, 'Constant', ())):
            # Python >= 3.8 parses every literal as ast.Constant; booleans
            # and strings are not accepted as numbers.
            value = node.value
            if isinstance(value, numbers.Number) and \
                    not isinstance(value, bool):
                return Constant(value)
        elif hasattr(ast, 'Num') and isinstance(node, ast.Num):
            # Python < 3.8 (ast.Num is removed in Python 3.14).
            return Constant(node.n)
        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return -cls._fromast(node.operand)
            elif isinstance(node.op, ast.UAdd):
                return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression such as '2x + 3y - 1'.

        A number or ')' directly followed by a name or '(' is read as an
        implicit product.  Raises SyntaxError on malformed input.
        """
        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        # mode='eval' accepts exactly one expression.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    @property
    def symbols(self):
        """Sorted tuple of the names of the symbols with non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a name or a Symbol), or 0 if it
        does not appear in the expression.

        Raises TypeError if symbol is neither a string nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        return False

    def values(self):
        """Yield every coefficient in sorted symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands contain symbols (the product
        would not be linear).
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant: let Python retry with other.__rmul__(self).
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant (exact rational division).

        Raises ValueError when the divisor contains symbols.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                    Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic_method
    def __rtruediv__(self, other):
        """
        Compute other / self, where other has already been converted to an
        Expression by the decorator.

        Raises ValueError when self contains symbols.
        """
        return other.__truediv__(self)

    def __str__(self):
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or a symbol (or always is true)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Return the expression multiplied by the least common multiple of
        the denominators of its coefficients and constant, so that all its
        values are integers.
        """
        # math.gcd rather than fractions.gcd, which was removed in Python 3.9.
        import math
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron where self == other."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        # Strict inequality over the integers: other - self - 1 >= 0.
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
333
334
class Constant(Expression):
    """
    A linear expression without symbols: a rational constant.
    """

    # Expression is slotted; an empty __slots__ keeps Constant instances
    # from getting a per-instance __dict__ as well.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a constant.

        With one argument, numerator is a rational number or a Constant.
        With two arguments the value is Fraction(numerator, denominator).

        Raises TypeError if a single argument is neither a rational number
        nor a Constant.
        """
        # object.__new__ directly: Expression.__new__ would itself build a
        # Constant and re-enter this method.
        self = object.__new__(cls)
        if denominator is None:
            if isinstance(numerator, numbers.Rational):
                self._constant = numerator
            elif isinstance(numerator, Constant):
                self._constant = numerator.constant
            else:
                raise TypeError('constant must be a rational number or a Constant instance')
        else:
            self._constant = Fraction(numerator, denominator)
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        # A constant is falsy exactly when its value is zero.
        return bool(self.constant)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._constant)
361
362
class Symbol(Expression):
    """
    A linear expression made of a single named symbol, with coefficient 1
    and constant 0.
    """

    # Declare only the attribute Symbol adds.  Re-declaring the slots
    # already defined by Expression would create shadowing duplicate
    # descriptors, which the data model leaves undefined.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create a symbol from a name (str) or an existing Symbol.

        Raises TypeError for any other argument type.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # A one-element tuple: tuple(name) would split the name into its
        # characters.
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)
391
def symbols(names):
    """
    Return a generator of Symbol instances.

    names is either a string of names separated by commas and/or
    whitespace (e.g. 'x, y z'), or an iterable of names / Symbols.
    """
    if isinstance(names, str):
        names = ' '.join(names.split(',')).split()
    return (Symbol(item) for item in names)
396
397
@_polymorphic_operator
def eq(a, b):
    """
    Return the polyhedron of points where a == b.

    a and b are linear expressions or rational numbers.  This delegates to
    Expression._eq, like le/lt/ge/gt delegate to the matching comparison
    methods.  Expression.__eq__ is structural equality of expressions and
    returns a bool, not a Polyhedron, so it must not be used here.
    """
    return a._eq(b)
401
@_polymorphic_operator
def le(a, b):
    """Return a.__le__(b): the polyhedron where a <= b for linear expressions or rationals."""
    comparison = a.__le__
    return comparison(b)
405
@_polymorphic_operator
def lt(a, b):
    """Return a.__lt__(b): the polyhedron where a < b for linear expressions or rationals."""
    comparison = a.__lt__
    return comparison(b)
409
@_polymorphic_operator
def ge(a, b):
    """Return a.__ge__(b): the polyhedron where a >= b for linear expressions or rationals."""
    comparison = a.__ge__
    return comparison(b)
413
@_polymorphic_operator
def gt(a, b):
    """Return a.__gt__(b): the polyhedron where a > b for linear expressions or rationals."""
    comparison = a.__gt__
    return comparison(b)
417
418
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as a tuple of equalities (expressions e standing
    for e == 0) and a tuple of inequalities (expressions e standing for
    e >= 0), all with integer coefficients and constants.  Set predicates
    are answered by isl on the basic set built by _toisl().
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron.

        equalities -- a string (parsed with fromstring), or an iterable of
            expressions e, each meaning e == 0.
        inequalities -- an iterable of expressions e, each meaning e >= 0.

        Raises TypeError if a string is given together with inequalities,
        or if a constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checkintegral(equalities, '{} == 0')
        self._inequalities = cls._checkintegral(inequalities, '{} >= 0')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _checkintegral(constraints, template):
        """
        Return constraints (None meaning none) as a tuple, raising
        TypeError for a constraint with a non-integer value; template
        formats the constraint in the error message.
        """
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                        + template.format(constraint))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def fromstring(cls, string):
        """
        Parse a polyhedron such as '{x >= 0, y <= 2x + 1}'.

        Constraints are separated by ',', ';', '&&', '/\\', '∧' or the word
        'and' (any case).  A single '=' is read as '==', comparisons may be
        chained, and a number or ')' directly followed by a name or '(' is
        an implicit product.  '{}' (what str() prints for the universe) and
        empty constraints are accepted.  Rational constants are scaled to
        integers exactly like the Expression comparison operators do.

        Raises SyntaxError on malformed input.
        """
        string = string.strip()
        string = re.sub(r'^\{\s*|\s*\}$', '', string)
        string = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', string)
        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        equalities = []
        inequalities = []
        # \b keeps "and" inside identifiers (e.g. "band") from splitting them.
        for cstr in re.split(r',|;|\band\b|&&|/\\|∧', string, flags=re.I):
            cstr = cstr.strip()
            if not cstr:
                continue
            # mode='eval' accepts exactly one expression.
            node = ast.parse(cstr, mode='eval').body
            if not isinstance(node, ast.Compare):
                raise SyntaxError('invalid syntax')
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                # Same normalization as Expression.__lt__ / __le__ / ...:
                # scale to integers first, then subtract 1 for strictness.
                if isinstance(op, ast.Lt):
                    inequalities.append((right - left)._toint() - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append((right - left)._toint())
                elif isinstance(op, ast.Eq):
                    equalities.append((left - right)._toint())
                elif isinstance(op, ast.GtE):
                    inequalities.append((left - right)._toint())
                elif isinstance(op, ast.Gt):
                    inequalities.append((left - right)._toint() - 1)
                else:
                    raise SyntaxError('invalid syntax')
                left = right
        return cls(equalities, inequalities)

    @property
    def equalities(self):
        """Tuple of expressions e meaning e == 0."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions e meaning e >= 0."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of all symbol names used by the constraints."""
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        """True unless the polyhedron is empty."""
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        """
        Return whether self and other contain the same points.

        Both are embedded in the union of their symbols, so polyhedra over
        different symbol sets can be compared.  Returns NotImplemented when
        other is not a Polyhedron.
        """
        if not isinstance(other, Polyhedron):
            return NotImplemented
        bset, other = self._islpair(other)
        return bool(libisl.isl_basic_set_is_equal(bset, other))

    def isempty(self):
        """Return whether isl reports the polyhedron as empty."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        """Return whether isl reports the polyhedron as the universe."""
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """Return whether the polyhedron has no point in common with other."""
        bset, other = self._islpair(other)
        return bool(libisl.isl_basic_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return whether every point of the polyhedron is in other."""
        bset, other = self._islpair(other)
        return bool(libisl.isl_basic_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return whether the polyhedron is a strict subset of other."""
        bset, other = self._islpair(other)
        return self._isstrictsubset(bset, other)

    def issuperset(self, other):
        """Return whether every point of other is in the polyhedron."""
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return whether the polyhedron is a strict superset of other."""
        bset, other = self._islpair(other)
        return self._isstrictsubset(other, bset)

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return the polyhedron of points common to self and all others.

        Each polyhedron is a conjunction of constraints, so the intersection
        is the polyhedron holding all their constraints together.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        # return a new polyhedron with elements in the polyhedron that are not in the other
        # NOTE(review): this returns the raw result of isl_set_subtract, not
        # a Polyhedron (_fromisl is not implemented yet), and hands basic
        # sets to an isl_set function that takes ownership of its arguments
        # -- confirm against the pypol.isl wrapper before relying on it.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Return the sorted list of symbols used by self or any of others."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _islpair(self, other):
        """Convert self and other to isl basic sets over their common symbols."""
        symbols = self._symbolunion(other)
        return self._toisl(symbols), other._toisl(symbols)

    @staticmethod
    def _isstrictsubset(bset1, bset2):
        """Return whether isl basic set bset1 is a strict subset of bset2."""
        return bool(libisl.isl_basic_set_is_subset(bset1, bset2)) and \
            not bool(libisl.isl_basic_set_is_subset(bset2, bset1))

    @staticmethod
    def _islconstraint(constraint, expression, symbols):
        """
        Copy the coefficients and constant of expression into the isl
        constraint, using the position in symbols as set dimension, and
        return the updated constraint.
        """
        for symbol, coefficient in expression.coefficients():
            val = str(coefficient).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            dim = symbols.index(symbol)
            constraint = libisl.isl_constraint_set_coefficient_val(
                constraint, libisl.isl_dim_set, dim, val)
        if expression.constant != 0:
            val = str(expression.constant).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            constraint = libisl.isl_constraint_set_constant_val(constraint, val)
        return constraint

    def _toisl(self, symbols=None):
        """
        Build an isl.BasicSet for the polyhedron.

        symbols -- the ordered symbols mapped to set dimensions 0..n-1;
            defaults to self.symbols.  Every constraint symbol must appear
            in it.
        """
        if symbols is None:
            symbols = self.symbols
        dimension = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        # isl_local_space_from_space consumes space; each constraint below
        # gets its own copy of ls, so ls itself is freed at the end.
        ls = libisl.isl_local_space_from_space(space)
        for equality in self.equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            ceq = self._islconstraint(ceq, equality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for inequality in self.inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            cin = self._islconstraint(cin, inequality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        libisl.isl_local_space_free(ls)
        return isl.BasicSet(bset)

    @classmethod
    def _fromisl(cls, bset, symbols):
        """
        Convert an isl basic set back into a Polyhedron (not implemented).

        isl example code gives the isl form as
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}", while our
        printer gives "{ [i0, i1] : 2i1 >= -2 - i0 }".
        """
        raise NotImplementedError
682
# Empty: built from eq(0, 1), a constraint that no point can satisfy.
Empty = eq(Constant(0), Constant(1))
# Universe: a polyhedron with no constraints, satisfied by every point.
Universe = Polyhedron()
685
if __name__ == '__main__':
    # Smoke test: print the isl basic set built for two sample polyhedra.
    samples = (
        '2a + 2b + 1 == 0',  # empty
        '3x + 2y + 3 == 0',  # not empty
    )
    for sample in samples:
        print(Polyhedron(sample)._toisl())