# Source: linpy.git / pypol / linear.py
# (blob 9584609e6878e142fd66302f281fb7ae44c69462)
import ctypes
import ctypes.util
import functools
import json
import numbers
from fractions import Fraction

try:
    from math import gcd
except ImportError:  # Python < 3.5 only provides fractions.gcd
    from fractions import gcd

from pypol import isl
from . import isl, islhelper
9
# Locate the shared isl library by name and load it through ctypes.
_libisl_path = ctypes.util.find_library('isl')
libisl = ctypes.CDLL(_libisl_path)

# isl_printer_get_str() hands back a C string; declare it so that ctypes
# converts the result to ``bytes`` instead of a bare integer.
libisl.isl_printer_get_str.restype = ctypes.c_char_p
13
# Public names exported by ``from <module> import *``.
__all__ = [
    'Expression',
    'constant',
    'symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
21
# Registry mapping each object seen by get_ids() to its integer identifier.
ids = {}


def get_ids(co):
    """
    Return the integer identifier associated with *co*.

    The first time a given *co* is seen it is registered in the module-level
    ``ids`` dictionary under the next free identifier (the current size of
    the registry); later calls return that same identifier.  *co* must be
    hashable.
    """
    # setdefault() performs the lookup and the insertion in a single step;
    # len(ids) is evaluated before the insertion, hence identifiers start at 0.
    return ids.setdefault(co, len(ids))
32
def _polymorphic_method(func):
    """
    Decorate a binary method so that its right operand is an Expression.

    An Expression operand is passed through unchanged, a rational number is
    first converted with constant(), and for any other operand the wrapper
    returns NotImplemented without calling *func*.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
43
def _polymorphic_operator(func):
    """
    Decorate a binary function so that its left operand is an Expression.

    A polymorphic operator should call a polymorphic method, hence only the
    left operand is tested here: a rational number is converted with
    constant(), an Expression is passed through, and anything else raises
    TypeError.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(constant(a), b)
        if not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
56
class Context:
    """
    Wrapper around an isl context handle obtained from isl_ctx_alloc().

    Instances can be passed directly as arguments to libisl functions thanks
    to the ``_as_parameter_`` protocol of ctypes.
    """

    # A one-element tuple: the bare string ('_ic') only worked by accident.
    __slots__ = ('_ic',)

    def __init__(self):
        # ctypes assumes a C int return type by default, which would truncate
        # the pointer on 64-bit platforms; declare it as a pointer instead.
        libisl.isl_ctx_alloc.restype = ctypes.c_void_p
        self._ic = libisl.isl_ctx_alloc()

    @property
    def _as_parameter_(self):
        # Plain Python integers are passed as C int; wrapping the address in
        # c_void_p hands the full pointer value back to isl.
        return ctypes.c_void_p(self._ic)

    # NOTE: there is deliberately no __del__ calling isl_ctx_free(): the
    # original author disabled it so that the context does not delete itself
    # right after being created.

    def __eq__(self, other):
        """Two contexts are equal when they wrap the same isl handle."""
        if not isinstance(other, Context):
            return NotImplemented
        return self._ic == other._ic

    def __hash__(self):
        # Defined alongside __eq__ so that contexts remain hashable.
        return hash(self._ic)
76
77
78
79
class Expression:
    """
    This class implements linear expressions.

    An expression is a rational constant plus a sum of symbols, each one
    multiplied by a rational coefficient.  Symbols whose coefficient is zero
    are not stored.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create an expression.

        coefficients -- a dict or an iterable of (symbol, coefficient) pairs;
            a symbol is a string or an Expression for which issymbol() holds.
            A string is instead handed to fromstring().
        constant -- the rational constant term.

        Raises TypeError if a symbol is not a string, or if a coefficient or
        the constant is not a rational number.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                # null coefficients are dropped to keep a canonical form
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the symbol names of the expression, in sorted order."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """The number of symbols of the expression."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not appear.

        Raises TypeError if symbol is neither a string nor a symbol
        expression.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    @property
    def coefficients(self):
        """Yield (symbol, coefficient) pairs, sorted by symbol name."""
        for symbol in self.symbols():
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    def isconstant(self):
        """Return True if the expression has no symbol."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield the coefficients in symbol order, then the constant."""
        for symbol in self.symbols():
            yield self.coefficient(symbol)
        yield self.constant

    def values_int(self):
        """
        Yield the same values as values(), each converted with int().

        The previous version returned from inside its loop and therefore
        only ever produced the first coefficient.
        """
        for value in self.values():
            yield int(value)

    def symbol(self):
        """
        Return the name of the symbol this expression consists of.

        Raises ValueError if the expression is not a symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """
        Return True if the expression is a single symbol with coefficient 1
        and no constant term.
        """
        if self._constant != 0 or len(self._coefficients) != 1:
            return False
        # 2*x or -x are not symbols: their string form is not a symbol name
        (coefficient,) = self._coefficients.values()
        return coefficient == 1

    def __bool__(self):
        """An expression is false only when it is the constant 0."""
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return the sum of two expressions."""
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return the difference of two expressions."""
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        # other - self == -(self - other)
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Return the product of two expressions, one of which is constant.

        Raises ValueError if neither operand is constant, since the result
        would not be linear.
        """
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                    for symbol, coefficient in self.coefficients}
            return Expression(coefficients, self.constant * factor)
        if self.isconstant():
            # Python does not try the reflected method when both operands
            # have the same type, so the operands are swapped explicitly.
            return other * self
        raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Return the quotient of the expression by a constant expression.

        Raises ValueError if other is not constant.
        """
        if other.isconstant():
            divisor = other.constant
            coefficients = {symbol: Fraction(coefficient, divisor)
                    for symbol, coefficient in self.coefficients}
            constant = Fraction(self.constant, divisor)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))

    def __rtruediv__(self, other):
        """
        Return the quotient of the rational number other by the expression.

        Raises ValueError if the expression is not constant.
        """
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            numerator = Expression(constant=other)
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(numerator._parenstr(), self._parenstr()))
        return NotImplemented

    def __str__(self):
        string = ''
        i = 0
        for symbol, coefficient in self.coefficients:
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                # stored coefficients are never null, hence negative here
                string += ' - {}*{}'.format(-coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, wrapped in parentheses
        unless it is a constant or a symbol (or always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                for symbol, coefficient in self.coefficients)
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        # a dict is not hashable: hash the sorted (symbol, coefficient) pairs
        return hash((tuple(self.coefficients), self._constant))

    def _canonify(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its values, so that no fractional value remains.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron defined by self == other."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the polyhedron defined by self <= other."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the polyhedron defined by other - self - 1 >= 0."""
        return Polyhedron(inequalities=[(other - self)._canonify() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the polyhedron defined by self >= other."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the polyhedron defined by self - other - 1 >= 0."""
        return Polyhedron(inequalities=[(self - other)._canonify() - 1])
329
330
def constant(numerator=0, denominator=None):
    """
    Return the constant expression numerator / denominator.

    A rational numerator given without denominator is used as is; in every
    other case the value is built with Fraction(numerator, denominator).
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        value = Fraction(numerator, denominator)
    else:
        value = numerator
    return Expression(constant=value)
336
def symbol(name):
    """
    Return the expression made of the single symbol name, with coefficient 1.

    Raises TypeError if name is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
341
def symbols(names):
    """
    Return a generator of symbol expressions, one per name.

    names is either an iterable of strings or a single string in which the
    names are separated by commas and/or whitespace.
    """
    if isinstance(names, str):
        # commas are treated like whitespace before splitting
        names = names.replace(',', ' ').split()
    return (symbol(each) for each in names)
346
347
@_polymorphic_operator
def eq(a, b):
    """Return the result of a._eq(b), the constraint a == b."""
    return a._eq(b)
351
@_polymorphic_operator
def le(a, b):
    """Return the result of a <= b, once a is a linear expression."""
    return a <= b
355
@_polymorphic_operator
def lt(a, b):
    """Return the result of a < b, once a is a linear expression."""
    return a < b
359
@_polymorphic_operator
def ge(a, b):
    """Return the result of a >= b, once a is a linear expression."""
    return a >= b
363
@_polymorphic_operator
def gt(a, b):
    """Return the result of a > b, once a is a linear expression."""
    return a > b
367
368
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as two lists of linear expressions with integer
    values: the equalities (each expression e stands for e == 0) and the
    inequalities (each expression e stands for e >= 0).  The corresponding
    isl basic set is built on creation and kept in the _bset attribute.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of constraint expressions.

        A string passed as first argument is handed to fromstring().
        Raises TypeError if a constraint has a non-integer value.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integer_constraints(equalities, '==')
        self._inequalities = cls._integer_constraints(inequalities, '>=')
        self._bset = self.to_isl()
        return self

    @staticmethod
    def _integer_constraints(constraints, relation):
        # Return constraints as a list, rejecting any non-integer value.
        checked = []
        if constraints is not None:
            for constraint in constraints:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                                '{} {} 0'.format(constraint, relation))
                checked.append(constraint)
        return checked

    @staticmethod
    def _isl_ptr(name, *args):
        # Call the libisl function name and return its result as a pointer.
        # ctypes assumes a C int return type by default and passes Python
        # integers as C int, either of which truncates 64-bit handles; the
        # result is thus fetched as c_void_p and wrapped again so that it
        # can be handed back to isl at full width.
        function = getattr(libisl, name)
        function.restype = ctypes.c_void_p
        return ctypes.c_void_p(function(*args))

    @property
    def equalities(self):
        """Yield the expressions e of the constraints e == 0."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Yield the expressions e of the constraints e >= 0."""
        yield from self._inequalities

    # NOTE: the former constant property and isconstant() method, copied
    # from Expression, read attributes that a polyhedron never has and
    # always raised AttributeError; they have been removed.

    def isempty(self):
        """
        Return True if isl reports the polyhedron as empty.

        Raises RuntimeError if isl signals an error (negative result).
        """
        status = libisl.isl_basic_set_is_empty(self._bset)
        if status < 0:
            raise RuntimeError('isl failed to test the polyhedron for '
                    'emptiness')
        return bool(status)

    def constraints(self):
        """Yield the equalities, then the inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Return the sorted list of the symbols used by the constraints."""
        s = set()
        for constraint in self.constraints():
            s.update(constraint.symbols())
        return sorted(s)

    @property
    def dimension(self):
        """The number of symbols of the polyhedron."""
        return len(self.symbols())

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        """Alias of isempty()."""
        return self.isempty()

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron, that is
        # whether other is a subset of the polyhedron
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and
        # all others
        # a poor man's implementation could be:
        #   equalities = list(self.equalities)
        #   inequalities = list(self.inequalities)
        #   for other in others:
        #       equalities.extend(other.equalities)
        #       inequalities.extend(other.inequalities)
        #   return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are
        # not in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(constraint)
                for constraint in self.equalities]
        constraints.extend('{} >= 0'.format(constraint)
                for constraint in self.inequalities)
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def _symbolunion(self, *others):
        """
        Return the sorted list of the symbols used by the polyhedron or by
        any of the others.
        """
        symbols = set(self.symbols())
        for other in others:
            symbols.update(other.symbols())
        return sorted(symbols)

    def to_isl(self, symbols=None):
        """
        Build and return the isl basic set describing the polyhedron.

        symbols -- the ordered symbol names giving the position of each set
            dimension; defaults to self.symbols().  Raises ValueError if a
            constraint uses a symbol that is not listed.

        A new Context is allocated on every call.
        """
        if symbols is None:
            symbols = self.symbols()
        symbols = list(symbols)
        ctx = Context()
        space = self._isl_ptr('isl_space_set_alloc', ctx, 0, len(symbols))
        bset = self._isl_ptr('isl_basic_set_universe',
                self._isl_ptr('isl_space_copy', space))
        ls = self._isl_ptr('isl_local_space_from_space', space)
        families = (
            ('isl_equality_alloc', self._equalities),
            ('isl_inequality_alloc', self._inequalities),
        )
        for allocator, constraints in families:
            for constraint in constraints:
                # isl_basic_set_add_constraint() takes ownership of the
                # constraint, so a fresh one is allocated for each
                # expression rather than reusing a single object.
                row = self._isl_ptr(allocator,
                        self._isl_ptr('isl_local_space_copy', ls))
                # values were checked to be integers on construction; int()
                # turns Fraction instances into something ctypes accepts
                row = self._isl_ptr('isl_constraint_set_constant_si', row,
                        int(constraint.constant))
                for name, coefficient in constraint.coefficients:
                    row = self._isl_ptr('isl_constraint_set_coefficient_si',
                            row, islhelper.isl_dim_set, symbols.index(name),
                            int(coefficient))
                bset = self._isl_ptr('isl_basic_set_add_constraint', bset,
                        row)
        # only copies of the local space were consumed above
        libisl.isl_local_space_free(ls)
        return bset

    def from_isl(self, bset):
        '''takes basic set in isl form and puts back into python version of
        polyhedron

        isl example code gives isl form as:
            "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
        our printer is giving form as:
            b'{ [i0] : 1 = 0 }'

        Work in progress: this currently returns the isl matrix of the
        equalities if the polyhedron has any, else of the inequalities if it
        has any, else None.
        '''
        # NOTE(review): the isl manual documents four isl_dim_type ordering
        # arguments for these matrix functions, only one (3, isl_dim_set) is
        # passed here -- confirm against the installed isl version.
        if self._equalities:
            return self._isl_ptr('isl_basic_set_equalities_matrix', bset, 3)
        if self._inequalities:
            return self._isl_ptr('isl_basic_set_inequalities_matrix', bset, 3)
        return None
604
605 #empty = eq(0,1)
606 empty = None
607 universe = None
608
609 if __name__ == '__main__':
610 ex1 = Expression(coefficients={'a': 1, 'x': 2}, constant=2)
611 ex2 = Expression(coefficients={'a': 3 , 'b': 2}, constant=3)
612 p = Polyhedron(inequalities=[ex1, ex2])
613 #p = eq(ex2, 0)# 2a+4 = 0, in fact 6a+3 = 0
614 #p.to_isl()
615
616 #universe = Polyhedron()