845fac3596cc53e870a30004baa26bdb59667f66
[linpy.git] / pypol / linear.py
import ctypes, ctypes.util
import functools
import numbers

try:
    from fractions import Fraction, gcd
except ImportError:  # fractions.gcd was removed in Python 3.9
    from fractions import Fraction
    from math import gcd

from . import isl
from .isl import libisl
9
10
# Names exported by ``from <this module> import *``.
__all__ = [
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
17
18
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so it also accepts rational numbers.

    The wrapped method always receives an Expression as its second operand:
    a rational number is first wrapped in a Constant, and any other type makes
    the wrapper return NotImplemented so Python can try the reflected method.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
29
def _polymorphic_operator(func):
    """
    Decorate a module-level binary function (eq, le, ...) on linear expressions.

    A rational left operand is wrapped in a Constant; a left operand that is
    neither rational nor an Expression raises TypeError.  Only the left
    operand is checked: the wrapped function is expected to delegate to a
    polymorphic method of it, which takes care of the right operand.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        is_number = isinstance(a, numbers.Rational)
        if not is_number and not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        if is_number:
            a = Constant(a)
        return func(a, b)
    return wrapper
42
43
# Module-wide isl context; Polyhedron._toisl allocates its isl spaces in it.
_main_ctx = isl.Context()
45
46
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (str) to rational coefficients and has a
    rational constant term.  The constructor normalizes its result: it
    returns a Constant when no coefficient is non-zero, and a Symbol for a
    lone symbol with coefficient 1 and a zero constant.

    The ordering operators (<, <=, >, >=) do not return booleans: they build
    a Polyhedron holding the corresponding integer constraint.  ``==`` is
    structural equality; use _eq() (or the module-level eq()) to build an
    equality constraint.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Build a linear expression, normalized to Constant or Symbol if possible.

        coefficients -- None, a dict, or an iterable of (symbol, coefficient)
            pairs; symbols are str or Symbol instances, coefficients are
            rational numbers or Constant instances.  A str is handed to
            fromstring() (not implemented yet).
        constant -- the constant term, a rational number or a Constant.

        Raises TypeError for arguments of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        # Zero coefficients are dropped so that they never count as symbols.
        coefficients = [(symbol, coefficient)
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.symbol
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def fromstring(cls, string):
        """Parse an expression from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def symbols(self):
        """Sorted tuple of the names of the symbols of the expression."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols of the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a str or a Symbol), 0 if absent.

        Raises TypeError if symbol is neither a str nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = symbol.symbol
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if the expression has no symbol (overridden by Constant)."""
        return False

    def values(self):
        """Yield every coefficient in sorted symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    @property
    def symbol(self):
        """The symbol name; only Symbol defines it, here it raises ValueError."""
        raise ValueError('not a symbol: {}'.format(self))

    def issymbol(self):
        """Return True if the expression is a Symbol (overridden by Symbol)."""
        return False

    def __bool__(self):
        # Normalization guarantees a plain Expression has a non-zero
        # coefficient, so it is never identically zero.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return self + other; a number operand is promoted to a Constant."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return self - other; a number operand is promoted to a Constant."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        """Return other - self (reached for number - expression)."""
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Return self * other, where at least one operand must be constant.

        Raises ValueError for the product of two non-constant expressions.
        """
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                            for symbol, coefficient in self.coefficients()}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant but other is not: Python then calls
        # other.__rmul__(self), which takes the constant branch above.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Return self / other with Fraction coefficients.

        Raises ValueError if other is not constant and ZeroDivisionError if
        it is the constant 0.
        """
        if other.isconstant():
            coefficients = {symbol: Fraction(coefficient, other.constant)
                            for symbol, coefficient in self.coefficients()}
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))

    @_polymorphic_method
    def __rtruediv__(self, other):
        """
        Return other / self (reached for number / expression).

        other has been promoted to a Constant, so the division is delegated
        to its __truediv__: it succeeds only when self is constant and raises
        ValueError otherwise.
        """
        return other / self

    def __str__(self):
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    # Print "a - 2*b" rather than "a + -2*b".
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or a
        symbol; always=True forces the parentheses."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" (structural) equality, not a constraint
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Return the expression scaled by the LCM of all its denominators.

        Denominators are positive, so the result has integer values and keeps
        the meaning of the constraints ``e == 0`` and ``e >= 0``.
        """
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron {self == other}, stored as self - other == 0."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the Polyhedron {self <= other}, stored as other - self >= 0."""
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the Polyhedron {self < other}; over the integers this is
        stored as other - self - 1 >= 0."""
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the Polyhedron {self >= other}, stored as self - other >= 0."""
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the Polyhedron {self > other}; over the integers this is
        stored as self - other - 1 >= 0."""
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
300
301
class Constant(Expression):
    """A linear expression without symbols, i.e. a rational constant."""

    def __new__(cls, numerator=0, denominator=None):
        """
        Build a constant.

        With denominator None, numerator must be a rational number or a
        Constant (raises TypeError otherwise); with a denominator, the value
        is Fraction(numerator, denominator).
        """
        if denominator is not None:
            value = Fraction(numerator, denominator)
        elif isinstance(numerator, numbers.Rational):
            value = numerator
        elif isinstance(numerator, Constant):
            value = numerator.constant
        else:
            raise TypeError('constant must be a rational number or a Constant instance')
        instance = object.__new__(cls)
        instance._constant = value
        instance._coefficients = {}
        instance._symbols = ()
        instance._dimension = 0
        return instance

    def isconstant(self):
        return True

    def __bool__(self):
        # Truthy exactly when the wrapped value is truthy (non-zero).
        return bool(self.constant)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._constant)
328
329
class Symbol(Expression):
    """A linear expression made of a single symbol with coefficient 1."""

    def __new__(cls, name):
        """
        Build a symbol from a name (a str) or from another Symbol.

        Raises TypeError for any other type of name.
        """
        if isinstance(name, Symbol):
            name = name.symbol
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # A one-element tuple: tuple(name) would split a multi-character
        # name such as 'ab' into ('a', 'b').
        self._symbols = (name,)
        self._symbol = name
        self._dimension = 1
        return self

    @property
    def symbol(self):
        """The name of the symbol."""
        return self._symbol

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._symbol)
354
def symbols(names):
    """
    Return a generator of Symbol instances, one per name.

    names -- either a string of names separated by commas and/or whitespace,
        or an iterable of names (str or Symbol).
    """
    if isinstance(names, str):
        # Commas count as separators, just like whitespace.
        names = ' '.join(names.split(',')).split()
    return (Symbol(each) for each in names)
359
360
# Module-level constraint builders.  Each takes two operands that are linear
# expressions or rational numbers; a rational left operand is wrapped in a
# Constant by _polymorphic_operator, and the right operand is handled by the
# Expression method called.

@_polymorphic_operator
def eq(a, b):
    """Return the Polyhedron {a == b} (built by Expression._eq)."""
    return a._eq(b)

@_polymorphic_operator
def le(a, b):
    """Return a <= b, i.e. the Polyhedron {a <= b} for linear expressions."""
    return a <= b

@_polymorphic_operator
def lt(a, b):
    """Return a < b, i.e. the Polyhedron {a < b} for linear expressions."""
    return a < b

@_polymorphic_operator
def ge(a, b):
    """Return a >= b, i.e. the Polyhedron {a >= b} for linear expressions."""
    return a >= b

@_polymorphic_operator
def gt(a, b):
    """Return a > b, i.e. the Polyhedron {a > b} for linear expressions."""
    return a > b
380
381
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is given by integer linear constraints: equalities, each
    read as ``expression == 0``, and inequalities, each read as
    ``expression >= 0``.  Emptiness is decided by isl; several set
    operations still raise NotImplementedError.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron from constraint expressions.

        equalities -- iterable of expressions e meaning e == 0, or a str
            handed to fromstring() (not implemented yet).
        inequalities -- iterable of expressions e meaning e >= 0.

        Raises TypeError if a constraint has a non-integer coefficient or
        constant, or if a str is given together with inequalities.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integerconstraints(equalities, '==')
        self._inequalities = cls._integerconstraints(inequalities, '>=')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _integerconstraints(constraints, relation):
        """
        Return constraints as a tuple, checking every value is an integer.

        relation ('==' or '>=') is only used to word the TypeError message.
        """
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                                    '{} {} 0'.format(constraint, relation))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def fromstring(cls, string):
        """Parse a polyhedron from a string (not implemented yet)."""
        raise NotImplementedError

    @property
    def equalities(self):
        """Tuple of expressions e, each meaning e == 0."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions e, each meaning e >= 0."""
        return self._inequalities

    @property
    def constraints(self):
        """The equalities followed by the inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of every symbol used by some constraint."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols."""
        return len(self.symbols)

    def __bool__(self):
        # A polyhedron is truthy when it is not empty.
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isempty(self):
        """
        Return True if isl_basic_set_is_empty reports the polyhedron as empty.

        Raises RuntimeError when isl signals an error (negative result)
        instead of silently reporting the polyhedron as empty.
        """
        bset = self._toisl()
        result = libisl.isl_basic_set_is_empty(bset)
        if result < 0:
            raise RuntimeError('isl_basic_set_is_empty failed')
        return bool(result)

    def isuniverse(self):
        raise NotImplementedError

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return the polyhedron of the points common to self and all others.

        The intersection of polyhedra is the conjunction of their constraints,
        so the constraint lists are simply concatenated.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(constraint)
                       for constraint in self.equalities]
        constraints += ['{} >= 0'.format(constraint)
                        for constraint in self.inequalities]
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Return the sorted list of symbols used by self or any of others."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Convert the polyhedron to an isl basic set wrapped in isl.BasicSet.

        symbols -- ordered symbols indexing the set dimensions (default:
            self.symbols).  Raises ValueError if a constraint uses a symbol
            missing from it.
        """
        if symbols is None:
            symbols = self.symbols
        num_coefficients = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, num_coefficients)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        # isl_basic_set_add_constraint takes ownership of the constraint, so
        # every row needs a freshly allocated one (reusing it would be a
        # use-after-free and would carry stale coefficients over).
        for constraint in self.equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            ceq = self._fillisl(ceq, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for constraint in self.inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            cin = self._fillisl(cin, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        # TODO(review): ls itself is never released (isl_local_space_free).
        return isl.BasicSet(bset)

    @staticmethod
    def _fillisl(cons, constraint, symbols):
        """Copy the constant and coefficients of constraint into the isl
        constraint cons and return the (possibly new) isl constraint."""
        if constraint.constant:
            cons = libisl.isl_constraint_set_constant_si(
                cons, int(constraint.constant))
        for symbol, coefficient in constraint.coefficients():
            index = symbols.index(symbol)
            cons = libisl.isl_constraint_set_coefficient_si(
                cons, libisl.isl_dim_set, index, int(coefficient))
        return cons

    @classmethod
    def _fromisl(cls, bset):
        """
        Convert an isl basic set back into a Polyhedron (not implemented yet).

        Notes from the draft: the isl example code gives sets such as
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}", while our
        printer gives b'{ [i0] : 1 = 0 }'.  The draft considered reading the
        constraints with isl_basic_set_equalities_matrix and
        isl_basic_set_inequalities_matrix.
        """
        raise NotImplementedError
586
# The empty polyhedron: its only constraint, 0 == 1, is unsatisfiable.
empty = eq(0, 1)
# The universe: a polyhedron without any constraint.
universe = Polyhedron()
589
590
if __name__ == '__main__':
    # Smoke test: convert a polyhedron defined by two inequalities to isl,
    # then check emptiness of it and of the contradiction 0 == 1.
    first = Expression(coefficients={'a': 1, 'x': 2}, constant=2)
    second = Expression(coefficients={'a': 3, 'b': 2}, constant=3)
    polyhedron = Polyhedron(inequalities=[first, second])
    print(polyhedron._toisl())
    print('empty ?', polyhedron.isempty())
    print('empty ?', eq(0, 1).isempty())