# be30fad7528b21b0b330f6f2b6a50e1fca49ed2b
# [linpy.git] / pypol / linear.py
1 import functools
2 import numbers
3 import json
4 import ctypes, ctypes.util
5 from pypol import isl, islhelper
6
7 from fractions import Fraction, gcd
8
# Locate the isl shared library.  find_library() returns None when nothing
# is found; passing None to CDLL would not load isl at all, so fail early
# with an explicit message instead of a confusing error further down.
_libisl_path = ctypes.util.find_library('isl')
if _libisl_path is None:
    raise OSError('isl shared library not found')
libisl = ctypes.CDLL(_libisl_path)

# ctypes assumes every foreign function returns a C int unless told
# otherwise; isl_printer_get_str returns a C string.
# NOTE(review): the other isl functions used in this module return pointers
# and still rely on the default int restype -- confirm on 64-bit platforms.
libisl.isl_printer_get_str.restype = ctypes.c_char_p

# Names exported by "from pypol.linear import *".
__all__ = [
    'Expression',
    'constant', 'symbol', 'symbols',
    'eq', 'le', 'lt', 'ge', 'gt',
    'Polyhedron',
    'empty', 'universe',
]
20
'''
def symbolToInt(self):
    make a dictionary of key:value pairs (letter:integer)
    look the symbol up in the dictionary
    return the matching integer value
    d = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14,
         'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}
    if self in d:
        num = d.get(self)
        return num
'''
32
# Module-wide registry mapping each key passed to get_ids() to the integer
# index it was assigned on first use.
ids = {}


def get_ids(co):
    """
    Return the integer index associated with co, allocating one if needed.

    The first time a given key is seen it receives the next free index,
    i.e. the number of keys registered so far (0, 1, 2, ...). Subsequent
    calls with the same key return the same index. The mapping is stored in
    the module-level ids dictionary.
    """
    # setdefault() evaluates len(ids) before inserting, so a new key gets
    # the current size of the registry as its index.
    return ids.setdefault(co, len(ids))
43
def _polymorphic_method(func):
    """
    Decorate a binary method so its right operand may be a plain number.

    The wrapper passes Expression operands through unchanged, converts
    rational numbers with constant() before calling func, and returns
    NotImplemented for any other operand type without calling func.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            # promote the number to a constant expression
            b = constant(b)
        return func(a, b)
    return wrapper
54
def _polymorphic_operator(func):
    """
    Decorate a binary function so its left operand may be a plain number.

    A polymorphic operator is expected to call a polymorphic method, which
    takes care of the right operand; only the left operand is checked here.
    Rational numbers are converted with constant(), Expression instances are
    passed through, and anything else raises TypeError.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
67
class Context:
    """
    Thin wrapper around the value returned by isl_ctx_alloc().

    Instances can be passed directly as arguments to ctypes foreign
    functions: ctypes looks up the _as_parameter_ attribute and passes its
    value instead of the Python object.
    """

    # One-element tuple: the trailing comma matters, ('_ic') is just a string.
    __slots__ = ('_ic',)

    def __init__(self):
        # NOTE(review): no restype is declared for isl_ctx_alloc, so ctypes
        # treats the result as a C int -- confirm this does not truncate the
        # returned pointer on 64-bit platforms.
        self._ic = libisl.isl_ctx_alloc()

    @property
    def _as_parameter_(self):
        """Value handed to foreign functions when a Context is an argument."""
        return self._ic

    # The destructor is deliberately disabled so that the context does not
    # delete itself right after being created:
    # def __del__(self):
    #     libisl.isl_ctx_free(self)

    def __eq__(self, other):
        """Two contexts are equal when they wrap the same isl value."""
        if not isinstance(other, Context):
            return False
        return self._ic == other._ic

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash on the same
        # value that __eq__ compares so equal contexts hash equally.
        return hash(self._ic)
87
88
89
90
class Expression:
    """
    This class implements linear expressions.

    An expression stores a dictionary mapping symbol names (strings) to
    their non-zero rational coefficients, plus a rational constant term.
    Arithmetic operators build new expressions; the ordering operators
    (<=, <, >=, >) build Polyhedron objects.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create an expression.

        coefficients may be None, a dictionary or an iterable of
        (symbol, coefficient) pairs; symbols are strings or expressions for
        which issymbol() is true. A single string argument is delegated to
        fromstring(). Zero coefficients are dropped.

        Raises TypeError for a symbol that is not a string, a coefficient or
        constant that is not a rational number, or a string argument
        combined with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    # NOTE(review): issymbol() does not check that the
                    # coefficient is 1, so str() may yield e.g. '2*x' here.
                    symbol = str(symbol)
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the symbol names of the expression in sorted order."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not appear.

        symbol is a string or an expression for which issymbol() is true;
        anything else raises TypeError.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    def isconstant(self):
        """Return True if the expression has no symbol."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield every coefficient in sorted symbol order, then the constant."""
        for symbol in self.symbols():
            yield self.coefficient(symbol)
        yield self.constant

    def values_int(self):
        """
        Return the coefficient of the first symbol, or int(constant) when
        the expression has no symbol.
        """
        # NOTE(review): the loop returns on its first iteration, so only
        # one value is ever produced -- confirm whether a generator of all
        # integer values was intended.
        for symbol in self.symbols():
            return self.coefficient(symbol)
        return int(self.constant)

    def symbol(self):
        """
        Return the name of the single symbol of the expression.

        Raises ValueError if issymbol() is false.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        # issymbol() guarantees exactly one symbol
        return next(self.symbols())

    def issymbol(self):
        """Return True for a single symbol with a zero constant term."""
        return len(self._coefficients) == 1 and self._constant == 0

    def __bool__(self):
        """An expression is false only when it is the constant 0."""
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return the sum of two expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return the difference of two expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        # other - self == -(self - other)
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant expression.

        Raises ValueError when both operands contain symbols, since the
        product would not be linear. Returns NotImplemented when only self
        is constant, which lets Python retry with the operands swapped.
        """
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                    for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, self.constant * factor)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant expression, producing Fraction coefficients.

        Raises ValueError when the divisor contains symbols.
        """
        if other.isconstant():
            divisor = other.constant
            coefficients = {symbol: Fraction(coefficient, divisor)
                    for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, Fraction(self.constant, divisor))
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    def __rtruediv__(self, other):
        """
        Divide the rational number other by this expression.

        Only possible when this expression is constant; otherwise the
        result would not be linear and ValueError is raised.
        """
        # The left operand reaches this method only when it is not an
        # Expression, so the only supported type is a rational number.
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                return Expression(constant=Fraction(other, self.constant))
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return NotImplemented

    def __str__(self):
        """Return a readable form such as 'x + 2*y - 1'."""
        string = ''
        symbols = sorted(self.symbols())
        i = 0
        for symbol in symbols:
            coefficient = self[symbol]
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            # a lone constant keeps its own sign
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return str(self), parenthesised unless the expression is a constant
        or a symbol (or always when always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                self.constant)

    @classmethod
    def fromstring(cls, string):
        """Parse an expression from a string (not implemented)."""
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        # A dict is not hashable, so hash a sorted tuple of its items.
        # __eq__ treats a constant expression as equal to the plain number,
        # hence constant expressions must hash like that number.
        if not self._coefficients:
            return hash(self._constant)
        return hash((tuple(sorted(self._coefficients.items())),
                self._constant))

    def _canonify(self):
        """
        Return the expression multiplied by the least common multiple of
        the denominators of its coefficients and constant.
        """
        # fractions.gcd was removed in Python 3.9; math.gcd exists since 3.5.
        try:
            from math import gcd
        except ImportError:
            from fractions import gcd
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron defined by self - other == 0."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the polyhedron defined by self - other <= 0."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the polyhedron defined by self - other + 1 <= 0."""
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the polyhedron defined by other - self <= 0."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the polyhedron defined by other - self + 1 <= 0."""
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
339
340
def constant(numerator=0, denominator=None):
    """
    Return a constant expression.

    With a single rational argument the value is used as is; otherwise the
    arguments are passed to Fraction(numerator, denominator) to build the
    constant.
    """
    if denominator is None and isinstance(numerator, numbers.Rational):
        # use the caller's value rather than a hard-coded number
        return Expression(constant=numerator)
    else:
        return Expression(constant=Fraction(numerator, denominator))
346
def symbol(name):
    """
    Return the expression made of the single symbol name with coefficient 1.

    Raises TypeError if name is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
351
def symbols(names):
    """
    Return a generator of symbol expressions, one per name.

    names is either an iterable of names or a single string in which the
    names are separated by commas and/or whitespace.
    """
    if isinstance(names, str):
        # treat commas like whitespace, then split on whitespace
        names = ' '.join(names.split(',')).split()
    return (symbol(each) for each in names)
356
357
@_polymorphic_operator
def eq(a, b):
    """Return the result of a._eq(b), the constraint a == b."""
    constraint = a._eq(b)
    return constraint
361
@_polymorphic_operator
def le(a, b):
    """Functional form of a <= b."""
    constraint = a <= b
    return constraint
365
@_polymorphic_operator
def lt(a, b):
    """Functional form of a < b."""
    constraint = a < b
    return constraint
369
@_polymorphic_operator
def ge(a, b):
    """Functional form of a >= b."""
    constraint = a >= b
    return constraint
373
@_polymorphic_operator
def gt(a, b):
    """Functional form of a > b."""
    constraint = a > b
    return constraint
377
378
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as two lists of constraints: equalities, each
    read as "constraint == 0", and inequalities, each read as
    "constraint <= 0". Every value of a constraint must be an integer.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of constraints.

        A single string argument is delegated to fromstring(). Raises
        TypeError if a constraint has a value whose denominator is not 1,
        or if a string is combined with inequalities.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integer_constraints(equalities, '==')
        self._inequalities = cls._integer_constraints(inequalities, '<=')
        self._bset = self.to_isl()
        # __new__ must hand back the instance itself; returning the isl
        # string would make Polyhedron(...) evaluate to a plain str.
        return self

    @staticmethod
    def _integer_constraints(constraints, relation):
        """Return constraints as a list, rejecting non-integer values."""
        accepted = []
        if constraints is not None:
            for constraint in constraints:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                                '{} {} 0'.format(constraint, relation))
                accepted.append(constraint)
        return accepted

    @property
    def equalities(self):
        """Iterator over the equality constraints."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterator over the inequality constraints."""
        yield from self._inequalities

    @property
    def constant(self):
        # NOTE(review): __new__ never assigns _constant, so this raises
        # AttributeError -- confirm whether the property can be removed.
        return self._constant

    def isconstant(self):
        # NOTE(review): __new__ never assigns _coefficients, so this raises
        # AttributeError -- confirm whether the method can be removed.
        return len(self._coefficients) == 0

    def isempty(self):
        """Return the truth value of isl_basic_set_is_empty on _bset."""
        # NOTE(review): _bset holds the string returned by to_isl(), not
        # an isl object -- confirm what should be passed to isl here.
        return bool(libisl.isl_basic_set_is_empty(self._bset))

    def constraints(self):
        """Yield the equalities followed by the inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield, in sorted order, every symbol used by a constraint."""
        s = set()
        for constraint in self.constraints():
            # symbols is a method of the constraint and has to be called
            s.update(constraint.symbols())
        yield from sorted(s)

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        # symbols() is a generator, which has no len()
        return len(list(self.symbols()))

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        # NOTE(review): the test below only looks at whether there are
        # constraints, it does not check emptiness -- confirm the intent.
        if self._equalities or self._inequalities:
            return False
        else:
            return True

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        """Alias of isempty()."""
        return self.isempty()

    def isuniverse(self):
        """Return True if the polyhedron equals the universe."""
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        # The earlier draft compared items of other with a generator
        # object, which can never be equal; report the operation as
        # missing, like the other set predicates, rather than answer False.
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and all
        # others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        """Return the constraints as '{c1 == 0, c2 <= 0, ...}'."""
        constraints = ['{} == 0'.format(constraint)
                for constraint in self.equalities]
        constraints.extend('{} <= 0'.format(constraint)
                for constraint in self.inequalities)
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        """Parse a polyhedron from a string (not implemented)."""
        raise NotImplementedError

    @staticmethod
    def _fill_constraint(constraint, constant, coefficients):
        """Set the constant and every coefficient on an isl constraint."""
        constraint = libisl.isl_constraint_set_constant_si(constraint,
                constant)
        for name, value in coefficients.items():
            constraint = libisl.isl_constraint_set_coefficient_si(constraint,
                    islhelper.isl_dim_set, get_ids(name), value)
        return constraint

    def to_isl(self):
        """
        Build an isl set for the polyhedron and return its printed form.

        The return value is str() applied to the result of
        isl_printer_get_str.
        """
        # NOTE(review): placeholder data -- the actual constraints are not
        # translated yet. Whenever the polyhedron has equalities (resp.
        # inequalities), one fixed constraint with constant 2 and
        # coefficient 1 on 'b' is added instead. The symbols still need a
        # lookup table giving the integer index of each name.
        d = {'_constant': 2, '_coefficients': {'b': 1}}
        constant = d['_constant']
        coefficients = d['_coefficients']
        space = libisl.isl_space_set_alloc(Context(), 0, len(coefficients))
        bset = libisl.isl_basic_set_empty(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(libisl.isl_space_copy(space))
        ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
        cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
        if self._equalities:
            ceq = self._fill_constraint(ceq, constant, coefficients)
            bset = libisl.isl_set_add_constraint(bset, ceq)
        if self._inequalities:
            cin = self._fill_constraint(cin, constant, coefficients)
            bset = libisl.isl_set_add_constraint(bset, cin)
        ip = libisl.isl_printer_to_str(Context())  # create string printer
        ip = libisl.isl_printer_print_set(ip, bset)  # print set to printer
        string = libisl.isl_printer_get_str(ip)  # get string from printer
        return str(string)

    def from_isl(self, bset):
        '''takes basic set in isl form and puts back into python version of polyhedron
        isl example code gives isl form as:
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}");'''
        # not implemented yet: always returns 0
        poly = 0
        return poly
596
# The empty polyhedron is described by a constraint that can never hold:
# 0 == 1.  (1 == 1 always holds and would describe the universe instead.)
empty = eq(0, 1)


# The universe is the polyhedron without any constraint.
universe = Polyhedron()