e69600ff951a699d6c95d6ed7e57c11d0310818e
[linpy.git] / pypol / linear.py
1 import functools
2 import numbers
3 import ctypes, ctypes.util
4
5 from fractions import Fraction, gcd
6
7 from . import isl, islhelper
8
9
def _load_isl():
    """Load the isl shared library and declare the C prototypes used here.

    Returns:
        The ``ctypes.CDLL`` handle for isl.

    Raises:
        ImportError: if ``ctypes.util.find_library('isl')`` finds nothing.
    """
    path = ctypes.util.find_library('isl')
    if path is None:
        # ctypes.CDLL(None) would silently open the running executable instead
        # of isl, and the problem would only surface later as a missing symbol.
        raise ImportError('cannot find the isl shared library')
    lib = ctypes.CDLL(path)

    pointer = ctypes.c_void_p
    c_int = ctypes.c_int
    c_uint = ctypes.c_uint
    # By default ctypes treats every return value as a C int and converts
    # Python int arguments to C int, which truncates isl's object pointers on
    # 64-bit platforms. Declare (restype, argtypes) for each function used.
    prototypes = {
        'isl_ctx_alloc': (pointer, []),
        'isl_ctx_free': (None, [pointer]),
        'isl_space_set_alloc': (pointer, [pointer, c_uint, c_uint]),
        'isl_space_copy': (pointer, [pointer]),
        'isl_basic_set_empty': (pointer, [pointer]),
        'isl_basic_set_universe': (pointer, [pointer]),
        'isl_basic_set_add_constraint': (pointer, [pointer, pointer]),
        'isl_basic_set_is_empty': (c_int, [pointer]),
        'isl_basic_set_free': (pointer, [pointer]),
        'isl_set_add_constraint': (pointer, [pointer, pointer]),
        'isl_local_space_from_space': (pointer, [pointer]),
        'isl_local_space_copy': (pointer, [pointer]),
        'isl_local_space_free': (pointer, [pointer]),
        'isl_equality_alloc': (pointer, [pointer]),
        'isl_inequality_alloc': (pointer, [pointer]),
        'isl_constraint_set_constant_si': (pointer, [pointer, c_int]),
        'isl_constraint_set_coefficient_si':
            (pointer, [pointer, c_int, c_int, c_int]),
        'isl_printer_to_str': (pointer, [pointer]),
        'isl_printer_print_set': (pointer, [pointer, pointer]),
        'isl_printer_print_basic_set': (pointer, [pointer, pointer]),
        'isl_printer_free': (pointer, [pointer]),
        # NOTE(review): per the isl documentation the returned char* belongs to
        # the caller; c_char_p copies it into bytes and the buffer is leaked.
        'isl_printer_get_str': (ctypes.c_char_p, [pointer]),
    }
    for name, (restype, argtypes) in prototypes.items():
        try:
            function = getattr(lib, name)
        except AttributeError:
            # Tolerate isl versions lacking a symbol; calling it will still
            # fail with AttributeError at the call site.
            continue
        function.restype = restype
        function.argtypes = argtypes
    return lib


libisl = _load_isl()
13
# Public names exported by ``from ... import *``.
__all__ = [
    'Expression',
    'constant',
    'symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
21
'''
Unused sketch, kept for reference: map a single-letter symbol to its
position in the alphabet.

def symbolToInt(self):
    # Make a dictionary of key: value pairs (letter: integer), look up the
    # matching symbol, and return its integer value.
    d = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14,
         'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}
    if self in d:
        num = d.get(self)
        return num
'''
33
# Module-wide registry mapping each symbol name to the id get_ids() gave it.
ids = {}


def get_ids(co):
    """Return the integer id of symbol *co*, allocating a new one on first use.

    Ids are handed out in first-seen order starting at 0 and are shared by the
    whole module through the global ``ids`` dictionary.
    """
    # setdefault evaluates len(ids) before inserting, so a new symbol receives
    # the next free id while a known symbol keeps its existing one. (The old
    # code also printed the whole registry on every new id; that debug output
    # is gone.)
    return ids.setdefault(co, len(ids))
44
def _polymorphic_method(func):
    """Decorate a binary Expression method so its right operand may be a rational.

    The wrapped method always receives an Expression as ``b``: a rational is
    promoted with constant(). Any other operand type makes the wrapper return
    NotImplemented, so Python can try the reflected operation or raise
    TypeError.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
55
def _polymorphic_operator(func):
    """Decorate a module-level binary operator so its left operand may be a rational.

    A rational left operand is promoted with constant(), an Expression passes
    through unchanged, and anything else raises TypeError. Only the left
    operand needs handling: the operator body calls a polymorphic Expression
    method on it, and that method normalises the right operand.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
68
class Context:
    """Wrapper around an isl context (``isl_ctx``) allocated through libisl.

    Instances can be passed directly as libisl call arguments: ctypes reads
    the ``_as_parameter_`` property to obtain the underlying handle.
    """

    # A real one-element tuple; ``('_ic')`` is merely the string '_ic'.
    __slots__ = ('_ic',)

    def __init__(self):
        self._ic = libisl.isl_ctx_alloc()

    @property
    def _as_parameter_(self):
        return self._ic

    # NOTE(review): there is deliberately no __del__ calling isl_ctx_free, so
    # every Context() leaks one isl_ctx. Presumably freeing it is only safe
    # once no isl object created in the context is still in use — confirm
    # before adding a finaliser.

    def __eq__(self, other):
        # Two wrappers are equal when they wrap the same isl context handle.
        if not isinstance(other, Context):
            return False
        return self._ic == other._ic

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash the handle so equal
        # contexts hash equally and instances remain usable in sets and dicts.
        return hash(self._ic)
88
89
90
91
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to non-zero rational
    coefficients and carries a rational constant term. Arithmetic with
    rationals and other expressions must stay linear: multiplying two
    non-constant expressions, or dividing by a non-constant one, raises
    ValueError. The comparison operators (<=, <, >=, >) and _eq() return
    Polyhedron constraints rather than booleans.
    """

    def __new__(cls, coefficients=None, constant=0):
        """Create an expression.

        Args:
            coefficients: a dict or an iterable of ``(symbol, coefficient)``
                pairs; a symbol is a string or a single-symbol Expression and
                a coefficient is a rational. A string argument is handed to
                fromstring() instead.
            constant: the rational constant term.

        Raises:
            TypeError: on a non-string symbol, a non-rational coefficient or
                constant, or when a string is combined with a constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = str(symbol)
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                # Zero coefficients are never stored, so the dict-based
                # __eq__ and __hash__ do not depend on how a value was built.
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self


    def symbols(self):
        """Yield the names of the symbols with a non-zero coefficient, sorted."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols appearing in the expression."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """Return the coefficient of *symbol* (a name or a symbol expression), 0 if absent."""
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """Yield ``(symbol, coefficient)`` pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True when the expression has no symbols."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield every coefficient in sorted symbol order, then the constant."""
        for symbol in self.symbols():
            yield self.coefficient(symbol)
        yield self.constant

    def values_int(self):
        # NOTE(review): the return inside the loop makes this return only the
        # first coefficient (or int(constant) when there are no symbols); it
        # was probably meant to yield every value as an int. No caller is
        # visible in this module, so the behaviour is left unchanged — confirm
        # the intent.
        for symbol in self.symbols():
            return self.coefficient(symbol)
        return int(self.constant)


    def symbol(self):
        """Return the name of a single-symbol expression; raise ValueError otherwise."""
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        for symbol in self.symbols():
            return symbol

    def issymbol(self):
        # A symbol is exactly ``1*name``: one coefficient equal to 1 and no
        # constant, so that 2*x is not mistaken for the symbol x.
        return (self._constant == 0 and len(self._coefficients) == 1
                and next(iter(self._coefficients.values())) == 1)

    def __bool__(self):
        # Only the constant 0 is falsy.
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # ``other - self``. The decorator promotes a rational *other* and
        # returns NotImplemented for unsupported types, instead of negating
        # NotImplemented.
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if self.isconstant():
            # constant * expression: scale *other* instead. Returning
            # NotImplemented here would not help, because Python never tries
            # __rmul__ when both operands are Expressions.
            return other * self
        raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic_method
    def __rtruediv__(self, other):
        # ``other / self`` where *other* was a rational, already promoted to a
        # constant Expression by the decorator. (The previous check
        # isinstance(other, self) raised TypeError on every call.)
        if self.isconstant():
            return Expression(constant=Fraction(other.constant, self.constant))
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other._parenstr(), self._parenstr()))

    def __str__(self):
        # Render like "2*x - y + 3": sorted symbols, then the constant, with
        # unit coefficients elided; the zero expression renders as "0".
        string = ''
        symbols = sorted(self.symbols())
        i = 0
        for symbol in symbols:
            coefficient = self[symbol]
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        # str(self), parenthesised unless it is a bare constant or symbol
        # (or always when *always* is true).
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        # A dict is unhashable, so hash an order-independent frozenset of the
        # (symbol, coefficient) pairs; expressions equal under __eq__ still
        # hash equally.
        return hash((frozenset(self._coefficients.items()), self._constant))

    def _canonify(self):
        # Scale by the lcm of all denominators so every coefficient and the
        # constant become integers, which Polyhedron requires.
        from math import gcd  # fractions.gcd no longer exists on Python 3.9+
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        # self == other  <=>  self - other == 0
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        # self <= other  <=>  self - other <= 0
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        # Assuming integer-valued symbols: self < other  <=>  self - other + 1 <= 0
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic_method
    def __ge__(self, other):
        # self >= other  <=>  other - self <= 0
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        # Assuming integer-valued symbols: self > other  <=>  other - self + 1 <= 0
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
340
341
def constant(numerator=0, denominator=None):
    """Return the constant linear expression ``numerator / denominator``.

    Without a denominator, a rational *numerator* is used as-is. Otherwise
    both arguments are forwarded to ``fractions.Fraction``, so anything
    Fraction accepts (for example the string ``'3/4'``) works too.
    """
    if denominator is None and isinstance(numerator, numbers.Rational):
        # This used to return the hard-coded constant 3, which turned every
        # rational promoted by the polymorphic operators into 3.
        return Expression(constant=numerator)
    return Expression(constant=Fraction(numerator, denominator))
347
def symbol(name):
    """Return the linear expression consisting of the single symbol *name*.

    Raises:
        TypeError: if *name* is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
352
def symbols(names):
    """Return a lazy generator of symbol expressions built from *names*.

    *names* is either an iterable of strings or a single string whose entries
    are separated by commas and/or whitespace.
    """
    name_list = names.replace(',', ' ').split() if isinstance(names, str) else names
    return (symbol(item) for item in name_list)
357
358
@_polymorphic_operator
def eq(a, b):
    """Build the equality constraint ``a == b`` through Expression._eq.

    Either operand may be a rational number instead of an Expression.
    """
    return a._eq(b)


@_polymorphic_operator
def le(a, b):
    """Build the constraint ``a <= b``; either operand may be a rational."""
    return a <= b


@_polymorphic_operator
def lt(a, b):
    """Build the constraint ``a < b``; either operand may be a rational."""
    return a < b


@_polymorphic_operator
def ge(a, b):
    """Build the constraint ``a >= b``; either operand may be a rational."""
    return a >= b


@_polymorphic_operator
def gt(a, b):
    """Build the constraint ``a > b``; either operand may be a rational."""
    return a > b
378
379
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by a list of equalities (each constraint
    expression is read as ``expr == 0``) and a list of inequalities (each read
    as ``expr <= 0``). Every coefficient and constant of a constraint must be
    an integer. Construction also builds the matching isl basic set, which
    backs isempty() and to_isl().
    """

    def __new__(cls, equalities=None, inequalities=None):
        """Create a polyhedron from iterables of constraint expressions.

        Raises:
            TypeError: if a constraint has a non-integer coefficient or
                constant, or when a string is combined with inequalities.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = []
        if equalities is not None:
            for constraint in equalities:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                                '{} == 0'.format(constraint))
                self._equalities.append(constraint)
        self._inequalities = []
        if inequalities is not None:
            for constraint in inequalities:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                                '{} <= 0'.format(constraint))
                self._inequalities.append(constraint)
        # Keep one context per polyhedron alive: the basic set lives in it.
        self._ctx = Context()
        self._bset = self._build_isl_bset()
        # Return the polyhedron itself, not its isl representation.
        return self


    @property
    def equalities(self):
        """Iterate over the constraint expressions read as ``expr == 0``."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterate over the constraint expressions read as ``expr <= 0``."""
        yield from self._inequalities

    # NOTE(review): constant and isconstant read _constant/_coefficients,
    # attributes Polyhedron never sets, so both always raise AttributeError.
    # They look copied from Expression and are kept only for interface
    # compatibility.
    @property
    def constant(self):
        return self._constant

    def isconstant(self):
        return len(self._coefficients) == 0


    def isempty(self):
        """Return True if isl finds no point in the polyhedron.

        Raises:
            RuntimeError: if isl reports an error (a negative result), which
                bool() would otherwise have turned into "empty".
        """
        result = libisl.isl_basic_set_is_empty(self._bset)
        if result < 0:
            raise RuntimeError('isl could not decide emptiness')
        return bool(result)

    def constraints(self):
        """Yield every equality, then every inequality."""
        yield from self.equalities
        yield from self.inequalities


    def symbols(self):
        """Yield the sorted names of all symbols used by any constraint."""
        names = set()
        for constraint in self.constraints():
            # symbols is a method on Expression; it must be called.
            names.update(constraint.symbols())
        yield from sorted(names)

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        # symbols() is a generator, which has no len().
        return len(list(self.symbols()))

    def __bool__(self):
        # A polyhedron is truthy unless it is empty.
        return not self.isempty()


    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        # Alias of isempty(), kept for callers using this spelling.
        return self.isempty()

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        # Not implemented yet, like issubset(). The earlier loop compared each
        # element with a generator object and so answered False for any
        # non-empty input.
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """Return the polyhedron of points common to this one and all *others*.

        The intersection is described by the concatenation of all the
        constraint lists.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} <= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def _build_isl_bset(self):
        # Build, in self._ctx, an isl basic set holding every constraint, with
        # one set dimension per symbol (in sorted order), and return it.
        positions = {name: index for index, name in enumerate(self.symbols())}
        space = libisl.isl_space_set_alloc(self._ctx, 0, len(positions))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        # Per the isl manual, isl_local_space_from_space takes ownership of
        # *space*, so it is not freed separately.
        ls = libisl.isl_local_space_from_space(space)
        for constraint in self._equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            ceq = self._fill_isl_constraint(ceq, constraint, positions, 1)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for constraint in self._inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            # isl reads an inequality as expr >= 0 (isl manual) while this
            # class stores expr <= 0, hence the negation.
            cin = self._fill_isl_constraint(cin, constraint, positions, -1)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        libisl.isl_local_space_free(ls)
        return bset

    @staticmethod
    def _fill_isl_constraint(isl_constraint, constraint, positions, sign):
        # Copy sign * constraint (constant and coefficients) into the isl
        # constraint and return it; *positions* maps symbol names to set
        # dimension indices.
        isl_constraint = libisl.isl_constraint_set_constant_si(
                isl_constraint, sign * int(constraint.constant))
        for name, coefficient in constraint.coefficients():
            isl_constraint = libisl.isl_constraint_set_coefficient_si(
                    isl_constraint, islhelper.isl_dim_set, positions[name],
                    sign * int(coefficient))
        return isl_constraint

    def to_isl(self):
        """Return the isl textual form of this polyhedron's basic set.

        Raises:
            RuntimeError: if isl fails to produce the string.
        """
        # A single context (self._ctx) is used for both the set and printer.
        printer = libisl.isl_printer_to_str(self._ctx)
        printer = libisl.isl_printer_print_basic_set(printer, self._bset)
        string = libisl.isl_printer_get_str(printer)
        libisl.isl_printer_free(printer)
        if string is None:
            raise RuntimeError('isl failed to print the polyhedron')
        # isl_printer_get_str has restype c_char_p, so the result is bytes;
        # decode it rather than calling str(), which would produce "b'...'".
        return string.decode()


    def from_isl(self, bset):
        '''Take a basic set in isl form and rebuild the Python polyhedron.

        Not implemented yet: this placeholder always returns 0. The isl
        example code gives the isl form as:
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}"
        '''
        poly = 0
        return poly
597
# The empty polyhedron: the single equality 1 == 0 has no solution.
# (eq(1, 1) would reduce to 0 == 0, which every point satisfies.)
empty = eq(1, 0)

# The universe: a polyhedron without any constraint.
universe = Polyhedron()