0712e1e238a9d1a9d2ec0a473317ffcca90c9179
[linpy.git] / pypol / linear.py
import ast
import functools
import numbers
import re

from fractions import Fraction

try:
    from math import gcd
except ImportError:  # Python < 3.5 only provides gcd in fractions
    from fractions import gcd

from pypol import isl
from pypol.isl import libisl
10
11
12 __all__ = [
13 'Expression', 'Constant', 'Symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
15 'Polyhedron',
16 'Empty', 'Universe'
17 ]
18
19
def _polymorphic_method(func):
    """
    Decorate a binary method so that its right operand may be a plain
    rational number.

    An Expression operand is forwarded untouched, a numbers.Rational operand
    is wrapped into a Constant first, and for any other operand the wrapper
    returns NotImplemented without calling func.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """
    Decorate a two-argument function so that its left operand may be a plain
    rational number.

    A numbers.Rational left operand is wrapped into a Constant, an Expression
    is forwarded untouched, and anything else raises TypeError.
    """
    # The decorated operator is expected to call a polymorphic method, which
    # takes care of the right operand; only the left one is checked here.
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
43
44
# Module-wide isl context; Polyhedron._toisl() passes it to the libisl calls
# that allocate spaces and read values.
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    A linear expression is a sum of symbols weighted by rational coefficients
    plus a rational constant term, e.g. 2*x + y/3 - 1. All attributes are set
    in __new__; no method of this class modifies them afterwards.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from a mapping (or an iterable of pairs) of
        symbols to coefficients, and a constant term.

        coefficients may also be a string, in which case it is parsed by
        fromstring() and constant must be left to its default value. Null
        coefficients are dropped; when what remains is a bare constant or a
        lone symbol with coefficient 1, a Constant or a Symbol is returned.

        Raises TypeError when a symbol is neither a string nor a Symbol, or
        when a coefficient or the constant is neither a rational number nor a
        Constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into an expression.

        Names become symbols, numeric literals become constants, and unary
        +/- as well as binary + - * / are evaluated with the operators of
        this class. Any other construct raises SyntaxError.
        """
        if isinstance(node, ast.Expression):
            # tree produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Module):
            # tree produced by ast.parse in its default 'exec' mode
            if len(node.body) != 1:
                raise SyntaxError('invalid syntax')
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif type(node).__name__ in ('Constant', 'Num'):
            # Literals are ast.Num nodes on old interpreters and ast.Constant
            # nodes on recent ones; matching on the class name supports both
            # without touching the deprecated ast.Num attribute.
            value = node.value if hasattr(node, 'value') else node.n
            if isinstance(value, bool) or not isinstance(value, numbers.Number):
                raise SyntaxError('invalid syntax')
            return Constant(value)
        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return -cls._fromast(node.operand)
            elif isinstance(node.op, ast.UAdd):
                return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse an expression such as '2x + 3(y - 1)'.

        A '*' is inserted between a number or a closing parenthesis and a
        following identifier or opening parenthesis, then the result is
        parsed as a Python expression. Raises SyntaxError on invalid input.
        """
        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        # The second positional parameter of ast.parse is the file name, so
        # the mode has to be passed by keyword.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    @property
    def symbols(self):
        """Sorted tuple of the symbol names of the expression."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols of the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a string or a Symbol), or 0 when
        the symbol does not appear in the expression.
        """
        if isinstance(symbol, Symbol):
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs, in the order of self.symbols."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """Constant term of the expression."""
        return self._constant

    def isconstant(self):
        return False

    def values(self):
        """Yield every coefficient, in symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # other has been turned into a Constant by the decorator; operands
        # that are not rational yield NotImplemented instead of an attempt to
        # negate NotImplemented.
        return other.__sub__(self)

    @_polymorphic_method
    def __mul__(self, other):
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant and other is not: let Python fall back on the
        # reflected operation of other, which handles that case above.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    def __rtruediv__(self, other):
        """
        Divide the rational number other by this expression, which is only
        linear when the expression is constant (ValueError otherwise).
        """
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            # other is a plain number here, hence the Constant wrapper to
            # format it
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(Constant(other)._parenstr(),
                        self._parenstr()))
        return NotImplemented

    def __str__(self):
        string = ''
        first = True
        for symbol, coefficient in self.coefficients():
            if coefficient == 1:
                string += symbol if first else ' + {}'.format(symbol)
            elif coefficient == -1:
                if first:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            elif first:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
            first = False
        constant = self.constant
        if constant != 0 and first:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        bare constant or symbol (always=True forces the parentheses).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
            for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                self.constant)

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron defined by self == other."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        # the trailing - 1 turns the strict inequality into a large one
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        # the trailing - 1 turns the strict inequality into a large one
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
326
327
class Constant(Expression):
    """
    A linear expression without any symbol, i.e. a rational number.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Build numerator/denominator when a denominator is given; otherwise
        numerator must be a rational number or another Constant, else
        TypeError is raised.
        """
        if denominator is not None:
            value = Fraction(numerator, denominator)
        elif isinstance(numerator, numbers.Rational):
            value = numerator
        elif isinstance(numerator, Constant):
            value = numerator.constant
        else:
            raise TypeError('constant must be a rational number or a Constant instance')
        self = object.__new__(cls)
        self._constant = value
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return self.constant != 0

    def __repr__(self):
        return '%s(%r)' % (self.__class__.__name__, self._constant)
354
355
class Symbol(Expression):
    """
    A linear expression made of a single symbol with coefficient 1 and no
    constant term.
    """

    def __new__(cls, name):
        """
        Build the symbol called name (a string, or a Symbol to copy).

        Raises TypeError for any other type of name.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # one-element tuple: tuple(name) would split the name into characters
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """Name of the symbol."""
        return self._name

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)
380
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name.

    names is either an iterable of names or a single string in which the
    names are separated by commas and/or whitespace. A tuple is returned
    rather than a one-shot generator so that the result can be unpacked,
    indexed, measured and iterated more than once.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(Symbol(name) for name in names)
385
386
@_polymorphic_operator
def eq(a, b):
    """
    Return the polyhedron defined by the constraint a == b.

    Expression.__eq__ is the structural ("normal") equality and returns a
    boolean; the constraint is built by Expression._eq instead.
    """
    return a._eq(b)
390
@_polymorphic_operator
def le(a, b):
    """
    Return a.__le__(b); a plain rational left operand is wrapped into a
    Constant by the decorator beforehand.
    """
    compare = a.__le__
    return compare(b)
394
@_polymorphic_operator
def lt(a, b):
    """
    Return a.__lt__(b); a plain rational left operand is wrapped into a
    Constant by the decorator beforehand.
    """
    compare = a.__lt__
    return compare(b)
398
@_polymorphic_operator
def ge(a, b):
    """
    Return a.__ge__(b); a plain rational left operand is wrapped into a
    Constant by the decorator beforehand.
    """
    compare = a.__ge__
    return compare(b)
402
@_polymorphic_operator
def gt(a, b):
    """
    Return a.__gt__(b); a plain rational left operand is wrapped into a
    Constant by the decorator beforehand.
    """
    compare = a.__gt__
    return compare(b)
406
407
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as two tuples of linear expressions with integer
    values: equalities (each expression == 0) and inequalities (each
    expression >= 0). Emptiness, inclusion and similar queries are delegated
    to isl through _toisl().
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron from iterables of expressions, or from a single
        string parsed by fromstring().

        Raises TypeError when a string is combined with inequalities, or when
        a constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checkconstraints(equalities, '==')
        self._inequalities = cls._checkconstraints(inequalities, '>=')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _checkconstraints(constraints, relation):
        # Return constraints as a tuple after checking that all their values
        # are integers; relation ('==' or '>=') only serves the error message.
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                            '{} {} 0'.format(constraint, relation))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def fromstring(cls, string):
        """
        Parse a polyhedron such as '{0 <= x < 10, x + 2y = 4}'.

        Constraints are separated by ',', ';', 'and', '&&', '/\\' or the
        logical-and sign; a single '=' is read as '==' and comparisons may be
        chained. An empty set of constraints gives the universe. Raises
        SyntaxError on anything that is not a comparison of linear
        expressions.
        """
        string = string.strip()
        string = re.sub(r'^\{\s*|\s*\}$', '', string)
        if not string:
            # no constraint at all, e.g. '{}' as printed for the universe
            return cls()
        equalities = []
        inequalities = []
        # Split before inserting the implicit '*': done the other way round,
        # '0 and y' would first be rewritten as '0*and y'. The word
        # boundaries keep identifiers containing 'and' in one piece.
        for cstr in re.split(r',|;|\band\b|&&|/\\|∧', string, flags=re.I):
            cstr = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', cstr.strip())
            cstr = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', cstr)
            # the mode must be passed by keyword: the second positional
            # parameter of ast.parse is the file name
            node = ast.parse(cstr, mode='eval').body
            if not isinstance(node, ast.Compare):
                raise SyntaxError('invalid syntax')
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    raise SyntaxError('invalid syntax')
                # chained comparison: the right operand is the left operand
                # of the next operator
                left = right
        return cls(equalities, inequalities)

    @property
    def equalities(self):
        """Tuple of the expressions constrained to be == 0."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of the expressions constrained to be >= 0."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of the symbol names used by the constraints."""
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        # works correctly when symbols is not passed
        # should be equal if values are the same even if symbols are different
        if not isinstance(other, Polyhedron):
            return NotImplemented
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        # NOTE(review): unlike issubset(), both operands are converted with
        # their own symbols rather than a common list -- confirm this is
        # intended.
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        # check if self(bset) is a subset of other
        # NOTE(review): this file hands BasicSet objects to isl_set_*
        # functions throughout; confirm the binding accepts that.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # isl tests whether its first argument is included in the second one
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and all
        # others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        # return a new polyhedron with elements in the polyhedron that are not in the other
        # NOTE(review): the value returned is whatever isl_set_subtract gives
        # back, not a Polyhedron, since _fromisl() is not implemented yet.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                    ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Return the sorted list of the symbols of self and of all others."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Convert the polyhedron into an isl.BasicSet.

        symbols gives the ordered symbol names mapped onto the isl set
        dimensions; it defaults to the symbols of the polyhedron itself.
        """
        if symbols is None:
            symbols = self.symbols
        dimension = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        # NOTE(review): ls is only ever copied below and never released --
        # check whether it should be freed once the constraints are built.
        groups = (
            (libisl.isl_equality_alloc, self.equalities),
            (libisl.isl_inequality_alloc, self.inequalities),
        )
        for alloc, expressions in groups:
            for expression in expressions:
                constraint = alloc(libisl.isl_local_space_copy(ls))
                for symbol, coefficient in expression.coefficients():
                    val = str(coefficient).encode()
                    val = libisl.isl_val_read_from_str(_main_ctx, val)
                    dim = symbols.index(symbol)
                    constraint = libisl.isl_constraint_set_coefficient_val(
                            constraint, libisl.isl_dim_set, dim, val)
                if expression.constant != 0:
                    val = str(expression.constant).encode()
                    val = libisl.isl_val_read_from_str(_main_ctx, val)
                    constraint = libisl.isl_constraint_set_constant_val(
                            constraint, val)
                bset = libisl.isl_basic_set_add_constraint(bset, constraint)
        return isl.BasicSet(bset)

    @classmethod
    def _fromisl(cls, bset, symbols):
        """
        Take a basic set in isl form and put it back into the python version
        of a polyhedron (not implemented yet).

        The isl example code gives the isl form as:
            "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}"
        whereas our printer gives the form as:
            { [i0, i1] : 2i1 >= -2 - i0 }
        """
        raise NotImplementedError
664
# The constraint -1 == 0 (i.e. 0 == 1) has no solution. It is built directly:
# going through Expression.__eq__ would give the boolean False instead of a
# polyhedron.
Empty = Polyhedron(equalities=[Constant(-1)])
# No constraint at all: every point belongs to the polyhedron.
Universe = Polyhedron()
667
if __name__ == '__main__':
    # Small demo: print the isl form of two polyhedra given as strings.
    for string in (
            '2a + 2b + 1 == 0',  # empty
            '3x + 2y + 3 == 0',  # not empty
    ):
        print(Polyhedron(string)._toisl())
671 p2 = Polyhedron('3x + 2y + 3 == 0') # not empty
672 print(p2._toisl())