348294ce808d0306890ffb548ab28c628448b9e6
[linpy.git] / pypol / linear.py
import ctypes
import ctypes.util
import functools
import numbers

from fractions import Fraction
try:
    from math import gcd
except ImportError:  # Python < 3.5: only fractions.gcd is available
    from fractions import gcd

from . import isl, islhelper
from .isl import libisl, Context, BasicSet
9
10
# Public API of the module.
__all__ = [
    # linear expressions
    'Expression',
    'constant',
    'symbol',
    'symbols',
    # constraint builders
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    # polyhedra
    'Polyhedron',
    'empty',
    'universe',
]
18
19
20 def _polymorphic_method(func):
21 @functools.wraps(func)
22 def wrapper(a, b):
23 if isinstance(b, Expression):
24 return func(a, b)
25 if isinstance(b, numbers.Rational):
26 b = constant(b)
27 return func(a, b)
28 return NotImplemented
29 return wrapper
30
31 def _polymorphic_operator(func):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func)
35 def wrapper(a, b):
36 if isinstance(a, numbers.Rational):
37 a = constant(a)
38 return func(a, b)
39 elif isinstance(a, Expression):
40 return func(a, b)
41 raise TypeError('arguments must be linear expressions')
42 return wrapper
43
44
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to non-zero rational
    coefficients and carries a rational constant term. Arithmetic operators
    return new Expression objects and accept rational numbers wherever an
    expression operand is expected. The ordering operators (<, <=, >, >=)
    build Polyhedron objects; == tests structural equality.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a dict or an iterable of (symbol, coefficient) pairs.
            A symbol is a string or a symbol expression; a coefficient is a
            rational number. A string argument is handed to fromstring().
        constant -- the rational constant term.

        Raises TypeError on malformed symbols, coefficients or constant, or
        if a string is combined with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = str(symbol)
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                # Zero coefficients are dropped so that equal expressions
                # always have equal internal dictionaries.
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the names of the symbols with a non-zero coefficient, sorted."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols appearing in the expression."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of *symbol* (a name or a symbol expression),
        or 0 if it does not appear.

        Raises TypeError if *symbol* is neither a string nor a symbol.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    @property
    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The rational constant term."""
        return self._constant

    def isconstant(self):
        """Return True if the expression has no symbol."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield the coefficients in sorted symbol order, then the constant."""
        for symbol in self.symbols():
            yield self.coefficient(symbol)
        yield self.constant

    def values_int(self):
        """Like values(), but convert every value with int()."""
        for symbol in self.symbols():
            yield int(self.coefficient(symbol))
        yield int(self.constant)

    def symbol(self):
        """
        Return the name of a symbol expression.

        Raises ValueError if the expression is not a symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        for symbol in self.symbols():
            return symbol

    def issymbol(self):
        """
        Return True if the expression is a single symbol with coefficient 1
        and no constant term (2*x and x + 1 are not symbols).
        """
        if len(self._coefficients) != 1 or self._constant != 0:
            return False
        coefficient, = self._coefficients.values()
        return coefficient == 1

    def __bool__(self):
        # Only the zero expression is falsy.
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # other - self; unsupported operand types yield NotImplemented via
        # the decorator instead of failing on a unary minus.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply two expressions, one of which must be constant.

        Raises ValueError if both factors contain symbols.
        """
        if other.isconstant():
            expression, factor = self, other.constant
        elif self.isconstant():
            # constant * expression must be handled here: Python does not
            # try __rmul__ when both operands are Expressions.
            expression, factor = other, self.constant
        else:
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        coefficients = {symbol: coefficient * factor
                for symbol, coefficient in expression.coefficients}
        return Expression(coefficients, expression.constant * factor)

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant expression; the results are Fractions.

        Raises ValueError if the divisor contains symbols, and
        ZeroDivisionError if it is zero.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
                for symbol, coefficient in self.coefficients}
        return Expression(coefficients, Fraction(self.constant, divisor))

    def __rtruediv__(self, other):
        # Only reached for number / Expression; Expression / Expression is
        # handled by __truediv__.
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return Expression(constant=Fraction(other, self.constant))

    def __str__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients):
            if coefficient == 1:
                term = symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                term = '-{}'.format(symbol) if i == 0 else \
                        ' - {}'.format(symbol)
            elif i == 0:
                term = '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                term = ' + {}*{}'.format(coefficient, symbol)
            else:
                term = ' - {}*{}'.format(-coefficient, symbol)
            string += term
        constant = self.constant
        if constant != 0 and self.isconstant():
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        # String form, parenthesized unless it is a constant or a symbol
        # (or always when *always* is set); used in error messages.
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        # Constant expressions compare equal to plain numbers (see __eq__),
        # so they must hash like those numbers. Otherwise hash the sorted,
        # immutable view of the coefficients (a dict is unhashable).
        if self.isconstant():
            return hash(self._constant)
        return hash((tuple(self.coefficients), self._constant))

    def _canonify(self):
        # Scale by the least common multiple of all denominators (a positive
        # integer), so every value of the result is an integer.
        denominators = [value.denominator for value in self.values()]
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b), denominators)
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        # Polyhedron {self - other == 0}.
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        # Polyhedron {other - self >= 0}.
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        # Strict inequality over integers: other - self - 1 >= 0.
        return Polyhedron(inequalities=[(other - self)._canonify() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        # Polyhedron {self - other >= 0}.
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        # Strict inequality over integers: self - other - 1 >= 0.
        return Polyhedron(inequalities=[(self - other)._canonify() - 1])
293
294
def constant(numerator=0, denominator=None):
    """
    Return a constant linear expression.

    A lone rational argument is used unchanged. In every other case both
    arguments are passed to fractions.Fraction, so a numerator/denominator
    pair or a string such as '1/2' is also accepted.
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        return Expression(constant=Fraction(numerator, denominator))
    return Expression(constant=numerator)
300
def symbol(name):
    """
    Return the linear expression made of the single symbol *name*
    (coefficient 1, no constant term).

    Raises TypeError if *name* is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
305
def symbols(names):
    """
    Return a lazy iterator of symbol expressions, one per name.

    *names* is either an iterable of strings or a single string in which
    names are separated by commas and/or whitespace.
    """
    name_list = names.replace(',', ' ').split() if isinstance(names, str) \
            else names
    return (symbol(each) for each in name_list)
310
311
@_polymorphic_operator
def eq(a, b):
    """Build the polyhedron {a == b}; a may be a rational number."""
    return a._eq(b)


@_polymorphic_operator
def le(a, b):
    """Build the polyhedron {a <= b}; a may be a rational number."""
    return a <= b


@_polymorphic_operator
def lt(a, b):
    """Build the polyhedron {a < b}; a may be a rational number."""
    return a < b


@_polymorphic_operator
def ge(a, b):
    """Build the polyhedron {a >= b}; a may be a rational number."""
    return a >= b


@_polymorphic_operator
def gt(a, b):
    """Build the polyhedron {a > b}; a may be a rational number."""
    return a > b
331
332
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by two lists of linear expressions with
    integer values: equalities, each read as ``expression == 0``, and
    inequalities, each read as ``expression >= 0``.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of equality and inequality
        expressions; a string first argument is handed to fromstring().

        Raises TypeError if a string is combined with inequalities, or if a
        constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integer_constraints(equalities, '{} == 0')
        self._inequalities = cls._integer_constraints(inequalities, '{} >= 0')
        return self

    @staticmethod
    def _integer_constraints(constraints, template):
        # Return the constraints as a new list after checking that all their
        # values are integers; template renders a constraint in the error.
        if constraints is None:
            return []
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: ' +
                            template.format(constraint))
            checked.append(constraint)
        return checked

    @property
    def equalities(self):
        yield from self._equalities

    @property
    def inequalities(self):
        yield from self._inequalities

    def isempty(self):
        # NOTE(review): self._bset is never assigned in this class, so as
        # written this raises AttributeError. It presumably should query the
        # set built by _to_isl() -- TODO confirm how BasicSet exposes its
        # isl pointer.
        return bool(libisl.isl_basic_set_is_empty(self._bset))

    def constraints(self):
        """Yield the equalities, then the inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Return the sorted list of symbol names used by any constraint."""
        s = set()
        for constraint in self.constraints():
            s.update(constraint.symbols())
        return sorted(s)

    @property
    def dimension(self):
        return len(self.symbols())

    def __bool__(self):
        # Intended contract: False if the polyhedron is empty, True otherwise.
        # NOTE(review): this only checks for the absence of constraints, so
        # any constrained polyhedron (even a non-empty one like {x >= 0}) is
        # falsy; switch to `not self.isempty()` once isempty() works.
        if self._equalities or self._inequalities:
            return False
        else:
            return True

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        # NOTE(review): stub that returns None; duplicates isempty().
        return

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # Test whether every element of other is in the polyhedron, i.e.
        # whether other is a subset of self.
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and all
        # others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def _symbolunion(self, *others):
        # Sorted union of the symbols of self and all others.
        symbols = set(self.symbols())
        for other in others:
            symbols.update(other.symbols())
        return sorted(symbols)

    def _to_isl(self, symbols=None):
        """
        Build an isl basic set holding the constraints of the polyhedron.

        symbols -- sequence of symbol names; the position of a name is the
            isl set dimension of that variable. Defaults to self.symbols().
            Every symbol used by a constraint must appear in it (otherwise
            list.index raises ValueError).

        Returns the set wrapped in a BasicSet.
        """
        if symbols is None:
            symbols = self.symbols()
        num_coefficients = len(symbols)
        ctx = Context()
        space = libisl.isl_space_set_alloc(ctx, 0, num_coefficients)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        # isl_local_space_from_space consumes space.
        ls = libisl.isl_local_space_from_space(space)
        # isl_basic_set_add_constraint consumes its constraint argument, so
        # a fresh constraint must be allocated for every expression.
        for constraint in self.equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            ceq = self._set_isl_values(ceq, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for constraint in self.inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            cin = self._set_isl_values(cin, constraint, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        # Release our own reference; each constraint holds its own copy.
        libisl.isl_local_space_free(ls)
        return BasicSet(bset)

    @staticmethod
    def _set_isl_values(islconstraint, constraint, symbols):
        # Copy the constant and coefficients of a linear expression into a
        # freshly allocated isl constraint and return the updated constraint.
        # Values are converted with int(): ctypes cannot pass Fractions, and
        # __new__ already checked that every denominator is 1.
        if constraint.constant:
            islconstraint = libisl.isl_constraint_set_constant_si(
                    islconstraint, int(constraint.constant))
        for symbol, coefficient in constraint.coefficients:
            islconstraint = libisl.isl_constraint_set_coefficient_si(
                    islconstraint, islhelper.isl_dim_set,
                    symbols.index(symbol), int(coefficient))
        return islconstraint

    def from_isl(self, bset):
        """
        Work in progress: turn an isl basic set back into a Python polyhedron.

        For reference, the isl example code writes sets as
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}", whereas our
        printer currently gives b'{ [i0] : 1 = 0 }'.

        Prints and returns the equalities matrix if this polyhedron has
        equalities, else the inequalities matrix if it has inequalities,
        else None.
        """
        # NOTE(review): in the isl C API the *_matrix functions take four
        # isl_dim_type arguments (the column order), not one -- TODO confirm
        # and pass all four.
        constraints = None
        if self._equalities:
            constraints = libisl.isl_basic_set_equalities_matrix(bset, 3)
        elif self._inequalities:
            constraints = libisl.isl_basic_set_inequalities_matrix(bset, 3)
        print(constraints)
        return constraints
546
# Placeholders for the empty and universe polyhedra; the original notes
# suggest eq(0, 1) and Polyhedron() respectively.
empty = universe = None
549
if __name__ == '__main__':
    # Manual smoke test: build a polyhedron from two inequalities and print
    # its isl conversion.
    first = Expression(coefficients={'a': 1, 'x': 2}, constant=2)
    second = Expression(coefficients={'a': 3, 'b': 2}, constant=3)
    polyhedron = Polyhedron(inequalities=[first, second])
    print(polyhedron._to_isl())