Simplify BasicSet.__str__
[linpy.git] / pypol / linear.py
import ctypes, ctypes.util
import functools
import numbers

from fractions import Fraction

try:
    from fractions import gcd  # removed from fractions in Python 3.9
except ImportError:
    from math import gcd

from . import isl
from .isl import libisl
9
10
11 __all__ = [
12 'Expression',
13 'constant', 'symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
15 'Polyhedron',
16 'empty', 'universe'
17 ]
18
19
20 def _polymorphic_method(func):
21 @functools.wraps(func)
22 def wrapper(a, b):
23 if isinstance(b, Expression):
24 return func(a, b)
25 if isinstance(b, numbers.Rational):
26 b = constant(b)
27 return func(a, b)
28 return NotImplemented
29 return wrapper
30
def _polymorphic_operator(func):
    """Decorate a free comparison function so its left operand may be rational.

    The wrapped function is expected to call a _polymorphic_method-decorated
    method, which already coerces the right operand, so only the left one is
    handled here: rationals are promoted with constant(), Expressions pass
    through, anything else raises TypeError.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(a, (numbers.Rational, Expression)):
            raise TypeError('arguments must be linear expressions')
        if isinstance(a, numbers.Rational):
            a = constant(a)
        return func(a, b)
    return wrapper
43
44
# Module-wide isl context; Polyhedron._to_isl passes it to
# isl_space_set_alloc when building isl objects.
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to non-zero rational
    coefficients and carries a rational constant term. Arithmetic with
    rationals and other expressions is supported as long as the result stays
    linear; the comparison operators build Polyhedron objects instead of
    booleans.
    """

    def __new__(cls, coefficients=None, constant=0):
        """Create a linear expression.

        coefficients: a dict, or an iterable of (symbol, coefficient) pairs.
            A symbol is a string or a single-symbol Expression. A string given
            as ``coefficients`` itself is handed to fromstring().
        constant: the rational constant term.

        Zero coefficients are dropped. Raises TypeError for non-string
        symbols, non-rational coefficients or constant, or a string combined
        with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the names of the symbols with a non-zero coefficient, sorted."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """Return the coefficient of *symbol* (0 if absent).

        symbol: a name, or a single-symbol Expression. Raises TypeError
        otherwise.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    @property
    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if the expression has no symbol."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield the coefficients in sorted symbol order, then the constant."""
        for symbol in self.symbols():
            yield self.coefficient(symbol)
        yield self.constant

    def values_int(self):
        """Like values(), but converts every value with int().

        Intended for integer-valued constraints; non-integral values are
        truncated by int().
        """
        for symbol in self.symbols():
            yield int(self.coefficient(symbol))
        yield int(self.constant)

    def symbol(self):
        """Return the name of a single-symbol expression; ValueError otherwise."""
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """Return True for a bare symbol: one coefficient equal to 1, no constant.

        Scaled symbols such as ``2*x`` or ``-x`` are not symbols; treating them
        as such would make their printed form ('2*x') be used as a name.
        """
        if len(self._coefficients) != 1 or self._constant != 0:
            return False
        (coefficient,) = self._coefficients.values()
        return coefficient == 1

    def __bool__(self):
        # Only the zero expression is falsy.
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """Multiply by a constant expression; raise ValueError if non-linear."""
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                    for symbol, coefficient in self.coefficients}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if self.isconstant():
            # Python never calls __rmul__ when both operands are Expressions,
            # so constant * non-constant must be handled here.
            return other * self
        raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """Divide by a constant expression; raise ValueError if non-linear."""
        if other.isconstant():
            coefficients = {symbol: Fraction(coefficient, other.constant)
                    for symbol, coefficient in self.coefficients}
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))

    def __rtruediv__(self, other):
        """Compute ``other / self`` for a rational *other*."""
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                return Expression(constant=Fraction(other, self.constant))
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return NotImplemented

    def __str__(self):
        """Render as e.g. ``'x - 2*y + 3'``; the zero expression is ``'0'``."""
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients):
            if i == 0:
                # The leading term keeps its own sign.
                if coefficient == 1:
                    string += symbol
                elif coefficient == -1:
                    string += '-{}'.format(symbol)
                else:
                    string += '{}*{}'.format(coefficient, symbol)
            else:
                sign = '+' if coefficient > 0 else '-'
                magnitude = abs(coefficient)
                if magnitude == 1:
                    string += ' {} {}'.format(sign, symbol)
                else:
                    string += ' {} {}*{}'.format(sign, magnitude, symbol)
        constant = self.constant
        if not string:
            string = '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                for symbol, coefficient in self.coefficients)
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" structural equality, not a constraint
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        # A dict is unhashable; hash its items as a frozenset instead, which
        # agrees with the dict comparison done in __eq__.
        return hash((frozenset(self._coefficients.items()), self._constant))

    def _canonify(self):
        """Scale by the LCM of all denominators so every value is an integer."""
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron {self == other}."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the polyhedron {other - self >= 0}."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the polyhedron {other - self - 1 >= 0} (integer points)."""
        return Polyhedron(inequalities=[(other - self)._canonify() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the polyhedron {self - other >= 0}."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the polyhedron {self - other - 1 >= 0} (integer points)."""
        return Polyhedron(inequalities=[(self - other)._canonify() - 1])
296
297
def constant(numerator=0, denominator=None):
    """Return the constant linear expression numerator/denominator.

    A lone rational numerator is used as-is; every other combination is
    turned into a Fraction first.
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        return Expression(constant=Fraction(numerator, denominator))
    return Expression(constant=numerator)
303
def symbol(name):
    """Return the linear expression consisting of the single symbol *name*."""
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
308
def symbols(names):
    """Return a generator of symbol expressions.

    names: an iterable of names, or one string of names separated by commas
    and/or whitespace.
    """
    name_list = names.replace(',', ' ').split() if isinstance(names, str) \
            else names
    return (symbol(each) for each in name_list)
313
314
@_polymorphic_operator
def eq(a, b):
    """Return the polyhedron where a == b (rationals are accepted)."""
    return a._eq(b)
318
@_polymorphic_operator
def le(a, b):
    """Return the polyhedron where a <= b (rationals are accepted)."""
    return a <= b
322
@_polymorphic_operator
def lt(a, b):
    """Return the polyhedron where a < b (rationals are accepted)."""
    return a < b
326
@_polymorphic_operator
def ge(a, b):
    """Return the polyhedron where a >= b (rationals are accepted)."""
    return a >= b
330
@_polymorphic_operator
def gt(a, b):
    """Return the polyhedron where a > b (rationals are accepted)."""
    return a > b
334
335
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by two lists of linear expressions whose
    coefficients and constants are all integers: equalities, each read as
    ``expr == 0``, and inequalities, each read as ``expr >= 0``. Most set
    operations are not implemented yet and raise NotImplementedError.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """Create a polyhedron from constraint expressions.

        A string first argument is handed to fromstring(). Raises TypeError
        if any constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integer_constraints(equalities, '{} == 0')
        self._inequalities = cls._integer_constraints(inequalities, '{} >= 0')
        return self

    @staticmethod
    def _integer_constraints(constraints, template):
        """Return *constraints* as a list, rejecting non-integer values.

        template formats a constraint for the TypeError message.
        """
        checked = []
        if constraints is None:
            return checked
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                            + template.format(constraint))
            checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """Yield the equality constraints (each read as ``expr == 0``)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Yield the inequality constraints (each read as ``expr >= 0``)."""
        yield from self._inequalities

    # NOTE(review): `constant` and `isconstant` read `_constant` and
    # `_coefficients`, which Polyhedron never sets, so both raise
    # AttributeError. They look copied from Expression.
    @property
    def constant(self):
        return self._constant

    def isconstant(self):
        return len(self._coefficients) == 0

    def isempty(self):
        # NOTE(review): `_bset` is never assigned in this class, so this raises
        # AttributeError; presumably it should query the set built by
        # _to_isl() — confirm against the isl wrapper's API.
        return bool(libisl.isl_basic_set_is_empty(self._bset))

    def constraints(self):
        """Yield all equalities, then all inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Return the sorted list of symbols used by any constraint."""
        s = set()
        for constraint in self.constraints():
            s.update(constraint.symbols())
        return sorted(s)

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        return len(self.symbols())

    def __bool__(self):
        # NOTE(review): the intent is "False iff the polyhedron is empty", but
        # this only reports whether no constraint was given, so any
        # constrained polyhedron is False. It should become
        # `not self.isempty()` once isempty() works.
        return not (self._equalities or self._inequalities)

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        """Alias of isempty()."""
        return self.isempty()

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        """Test whether every element of *other* is in the polyhedron."""
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and all
        # others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(constraint)
                for constraint in self.equalities]
        constraints += ['{} >= 0'.format(constraint)
                for constraint in self.inequalities]
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        return '{}(equalities={!r}, inequalities={!r})'.format(
                self.__class__.__name__, list(self.equalities),
                list(self.inequalities))

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def _symbolunion(self, *others):
        """Return the sorted union of the symbols of self and *others*."""
        symbols = set(self.symbols())
        for other in others:
            symbols.update(other.symbols())
        return sorted(symbols)

    def _to_isl(self, symbols=None):
        """Build an isl basic set with the same constraints as this polyhedron.

        symbols: ordered symbol names mapped to the isl set dimensions;
        defaults to self.symbols(). Every symbol used by a constraint must be
        in it (list.index raises ValueError otherwise).
        """
        if symbols is None:
            symbols = self.symbols()
        symbols = list(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, len(symbols))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        # isl_basic_set_add_constraint takes ownership of the constraint, so a
        # fresh one is allocated for every row instead of reusing one object.
        for constraint in self.equalities:
            c = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            c = self._fill_isl_constraint(c, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, c)
        for constraint in self.inequalities:
            c = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            c = self._fill_isl_constraint(c, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, c)
        # Only copies of `ls` were handed to isl; release the original.
        libisl.isl_local_space_free(ls)
        return isl.BasicSet(bset)

    @staticmethod
    def _fill_isl_constraint(c, constraint, symbols):
        """Copy the constant and coefficients of *constraint* into isl constraint *c*.

        Values are converted with int() because the isl ``*_si`` setters take
        C ints; the constructor guarantees they are integral.
        """
        if constraint.constant:
            c = libisl.isl_constraint_set_constant_si(c,
                    int(constraint.constant))
        for symbol, coefficient in constraint.coefficients:
            c = libisl.isl_constraint_set_coefficient_si(c,
                    libisl.isl_dim_set, symbols.index(symbol),
                    int(coefficient))
        return c

    def from_isl(self, bset):
        """Return the isl constraint matrix of *bset* (work in progress).

        Intended to convert an isl basic set back into a Python polyhedron.
        For now it returns the raw isl equalities matrix if this polyhedron has
        equalities, else the inequalities matrix if it has inequalities, else
        None.
        """
        # NOTE(review): isl_basic_set_{equalities,inequalities}_matrix take four
        # isl_dim_type arguments giving the column order; only one (3, the
        # value used for isl_dim_set) is passed here.
        constraints = None
        if self._equalities:
            constraints = libisl.isl_basic_set_equalities_matrix(bset, 3)
        elif self._inequalities:
            constraints = libisl.isl_basic_set_inequalities_matrix(bset, 3)
        return constraints
548
# Exported constants. `empty` is the equality 0 == 1 (stored by eq() as
# -1 == 0), which no point satisfies; `universe` has no constraints.
empty = eq(0, 1)
universe = Polyhedron()

if __name__ == '__main__':
    # Small demo: convert a two-inequality polyhedron to isl and print it.
    ex1 = Expression(coefficients={'a': 1, 'x': 2}, constant=2)
    ex2 = Expression(coefficients={'a': 3, 'b': 2}, constant=3)
    p = Polyhedron(inequalities=[ex1, ex2])
    bs = p._to_isl()
    print(bs)