Clean up an import in domains.py
[linpy.git] / pypol / polyhedra.py
1
2 import functools
3 import math
4 import numbers
5
6 from . import islhelper
7
8 from .islhelper import mainctx, libisl
9 from .geometry import GeometricObject
10 from .coordinates import Point
11 from .linexprs import Expression, Symbol, Rational
12 from .domains import Domain
13
14
# Public API of this module.
__all__ = [
    # The polyhedron class itself.
    'Polyhedron',
    # Constraint constructors (comparison operators).
    'Lt',
    'Le',
    'Eq',
    'Ne',
    'Ge',
    'Gt',
    # Predefined polyhedra.
    'Empty',
    'Universe',
]
20
21
class Polyhedron(Domain):
    """A convex polyhedron described by linear equalities and inequalities.

    Each constraint is stored as a linear ``Expression`` ``e``: an equality
    means ``e == 0`` and an inequality means ``e >= 0`` (this is how
    ``__repr__`` and ``tosympy`` render them). Constructed instances are
    normalized by a round-trip through an isl basic set.
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """Create a polyhedron.

        ``equalities`` may also be a string (parsed with ``fromstring``) or a
        ``GeometricObject`` (converted with its ``aspolyhedron`` method); in
        both cases ``inequalities`` must be omitted.

        Both constraint arguments may be any iterable of ``Expression``
        objects; they are copied and never modified in place.

        Raises:
            TypeError: if extra arguments are given with a string or
                geometric object, or if a constraint is not an Expression.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        elif isinstance(equalities, GeometricObject):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities.aspolyhedron()
        equalities = cls._scaleconstraints(
            equalities, 'equalities must be linear expressions')
        inequalities = cls._scaleconstraints(
            inequalities, 'inequalities must be linear expressions')
        symbols = cls._xsymbols(equalities + inequalities)
        islbset = cls._toislbasicset(equalities, inequalities, symbols)
        return cls._fromislbasicset(islbset, symbols)

    @staticmethod
    def _scaleconstraints(constraints, message):
        """Return a new list holding ``scaleint()`` of each constraint.

        ``None`` yields an empty list. Raises ``TypeError(message)`` on the
        first element that is not an ``Expression``. The input is not
        mutated, so tuples and generators are accepted too.
        """
        if constraints is None:
            return []
        scaled = []
        for constraint in constraints:
            if not isinstance(constraint, Expression):
                raise TypeError(message)
            scaled.append(constraint.scaleint())
        return scaled

    @property
    def equalities(self):
        """Tuple of expressions ``e`` constrained by ``e == 0``."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions ``e`` constrained by ``e >= 0``."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities, as a single tuple."""
        return self._constraints

    @property
    def polyhedra(self):
        """A 1-tuple containing this polyhedron."""
        return self,

    def disjoint(self):
        """Return self: a single polyhedron is already a disjoint union."""
        return self

    def isuniverse(self):
        """Return True if isl reports this polyhedron as the universe."""
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        universe = bool(libisl.isl_basic_set_is_universe(islbset))
        libisl.isl_basic_set_free(islbset)
        return universe

    def aspolyhedron(self):
        """Return self (already a polyhedron)."""
        return self

    def __contains__(self, point):
        """Return True if *point* satisfies every constraint.

        Raises:
            TypeError: if *point* is not a Point.
            ValueError: if *point* has different symbols than self.
        """
        if not isinstance(point, Point):
            raise TypeError('point must be a Point instance')
        if self.symbols != point.symbols:
            raise ValueError('arguments must belong to the same space')
        for equality in self.equalities:
            if equality.subs(point.coordinates()) != 0:
                return False
        for inequality in self.inequalities:
            if inequality.subs(point.coordinates()) < 0:
                return False
        return True

    def subs(self, symbol, expression=None):
        """Return a new Polyhedron with ``subs`` applied to each constraint."""
        equalities = [equality.subs(symbol, expression)
            for equality in self.equalities]
        inequalities = [inequality.subs(symbol, expression)
            for inequality in self.inequalities]
        return Polyhedron(equalities, inequalities)

    @classmethod
    def _fromislbasicset(cls, islbset, symbols):
        """Build a Polyhedron from an isl basic set and free *islbset*.

        The i-th set dimension of *islbset* is read as ``symbols[i]``. The
        resulting symbol set is recomputed from the constraints with
        ``_xsymbols``, so it may differ from *symbols*.
        """
        islconstraints = islhelper.isl_basic_set_constraints(islbset)
        equalities = []
        inequalities = []
        for islconstraint in islconstraints:
            constant = libisl.isl_constraint_get_constant_val(islconstraint)
            constant = islhelper.isl_val_to_int(constant)
            coefficients = {}
            for index, symbol in enumerate(symbols):
                coefficient = libisl.isl_constraint_get_coefficient_val(
                    islconstraint, libisl.isl_dim_set, index)
                coefficient = islhelper.isl_val_to_int(coefficient)
                if coefficient != 0:
                    coefficients[symbol] = coefficient
            expression = Expression(coefficients, constant)
            if libisl.isl_constraint_is_equality(islconstraint):
                equalities.append(expression)
            else:
                inequalities.append(expression)
        libisl.isl_basic_set_free(islbset)
        # Bypass __new__ (which would re-run the isl round-trip).
        self = object.__new__(Polyhedron)
        self._equalities = tuple(equalities)
        self._inequalities = tuple(inequalities)
        self._constraints = tuple(equalities + inequalities)
        self._symbols = cls._xsymbols(self._constraints)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislbasicset(cls, equalities, inequalities, symbols):
        """Return a new isl basic set for the given constraints.

        ``symbols[i]`` becomes set dimension ``i``. The caller owns the
        returned set and must free it (or hand it to ``_fromislbasicset``).
        """
        dimension = len(symbols)
        indices = {symbol: index for index, symbol in enumerate(symbols)}
        islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension)
        islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp))
        # isl_local_space_from_space takes ownership of islsp.
        islls = libisl.isl_local_space_from_space(islsp)
        for equality in equalities:
            isleq = libisl.isl_equality_alloc(
                libisl.isl_local_space_copy(islls))
            isleq = cls._setislconstraint(isleq, equality, indices)
            islbset = libisl.isl_basic_set_add_constraint(islbset, isleq)
        for inequality in inequalities:
            islin = libisl.isl_inequality_alloc(
                libisl.isl_local_space_copy(islls))
            islin = cls._setislconstraint(islin, inequality, indices)
            islbset = libisl.isl_basic_set_add_constraint(islbset, islin)
        # Only copies of islls were consumed above; release the original.
        libisl.isl_local_space_free(islls)
        return islbset

    @staticmethod
    def _setislconstraint(islconstraint, expression, indices):
        """Copy the coefficients and constant of *expression* into an isl
        constraint, returning the updated constraint.

        ``indices`` maps each symbol to its set-dimension index.
        """
        for symbol, coefficient in expression.coefficients():
            islval = str(coefficient).encode()
            islval = libisl.isl_val_read_from_str(mainctx, islval)
            islconstraint = libisl.isl_constraint_set_coefficient_val(
                islconstraint, libisl.isl_dim_set, indices[symbol], islval)
        if expression.constant != 0:
            islval = str(expression.constant).encode()
            islval = libisl.isl_val_read_from_str(mainctx, islval)
            islconstraint = libisl.isl_constraint_set_constant_val(
                islconstraint, islval)
        return islconstraint

    @classmethod
    def fromstring(cls, string):
        """Parse *string* with ``Domain.fromstring``.

        Raises:
            ValueError: if the parsed domain is not a single Polyhedron.
        """
        domain = Domain.fromstring(string)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(string))
        return domain

    def __repr__(self):
        """Return 'Empty', 'Universe', a single constraint, or And(...)."""
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            strings = []
            for equality in self.equalities:
                strings.append('0 == {}'.format(equality))
            for inequality in self.inequalities:
                strings.append('0 <= {}'.format(inequality))
            if len(strings) == 1:
                return strings[0]
            else:
                return 'And({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy expression with ``Domain.fromsympy``.

        Raises:
            ValueError: if the result is not a single Polyhedron.
        """
        domain = Domain.fromsympy(expr)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(expr))
        return domain

    def tosympy(self):
        """Return ``sympy.And`` of ``Eq(e, 0)`` and ``Ge(e, 0)`` constraints."""
        import sympy
        constraints = []
        for equality in self.equalities:
            constraints.append(sympy.Eq(equality.tosympy(), 0))
        for inequality in self.inequalities:
            constraints.append(sympy.Ge(inequality.tosympy(), 0))
        return sympy.And(*constraints)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """Sort 2D points by polar angle around their centroid.

        Lists of 3 or fewer points are returned unchanged. Points are used
        as dict keys, so they must be hashable.
        """
        if len(points) <= 3:
            return points
        # Vector is not imported at module level; presumably it lives next
        # to Point in .coordinates — NOTE(review): confirm.
        from .coordinates import Vector
        o = sum((Vector(point) for point in points)) / len(points)
        o = Point(o.coordinates())
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angles[m] = math.atan2(dy, dx)
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """Sort coplanar 3D points by signed angle around their centroid.

        Angles are measured from the first point, signed by the normal
        ``oa x ob`` built from the first two points. Lists of 3 or fewer
        points are returned unchanged. Points must be hashable.
        """
        if len(points) <= 3:
            return points
        # See _sort_polygon_2d about where Vector comes from.
        from .coordinates import Vector
        o = sum((Vector(point) for point in points)) / len(points)
        o = Point(o.coordinates())
        a, b = points[:2]
        oa = Vector(o, a)
        ob = Vector(o, b)
        norm_oa = oa.norm()
        u = (oa.cross(ob)).asunit()
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            # Clamp: rounding can push the cosine slightly outside [-1, 1],
            # which would make math.acos raise ValueError.
            cosinus = max(-1., min(1., oa.dot(om) / normprod))
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def plot(self):
        """Draw a 2-dimensional polyhedron with matplotlib and show it.

        Returns the list of vertex tuples given to the matplotlib Path. The
        last tuple is a placeholder for the CLOSEPOLY code. 3-dimensional
        plotting is not implemented and returns 0.

        Raises:
            TypeError: if the polyhedron has neither 2 nor 3 symbols.
        """
        import matplotlib.pyplot as plt
        from matplotlib.path import Path
        import matplotlib.patches as patches

        dimension = len(self.symbols)
        if dimension == 3:
            # 3D plotting is not implemented; keep the placeholder result.
            return 0
        if dimension != 2:
            raise TypeError(
                'cannot plot a polyhedron of dimension {}'.format(dimension))

        # NOTE(review): vertices are drawn in the order vertices() returns
        # them; if that order is not cyclic the outline may self-intersect.
        points = []
        for vert in self.vertices():
            points.append(tuple(vert.get(symbol)
                for symbol in sorted(vert, key=Symbol.sortkey)))
        # The vertex attached to CLOSEPOLY is ignored by matplotlib.
        points.append((0.0, 0.0))
        codes = ([Path.MOVETO] + [Path.LINETO] * (len(points) - 2)
            + [Path.CLOSEPOLY])
        path = Path(points, codes)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        patch = patches.PathPatch(path, facecolor='blue', lw=2)
        ax.add_patch(patch)
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        plt.show()
        return points
285
286
def _polymorphic(func):
    """Let a binary constraint constructor accept rational numbers.

    Each operand of the wrapped function is normalized before the call:
    ``numbers.Rational`` values (including ``int``) are wrapped in
    ``Rational``, and ``Expression`` instances are passed through as-is.
    The left operand is checked before the right one.

    Raises:
        TypeError: if an operand is neither a rational number nor a linear
            expression.
    """
    def convert(value, name):
        # Rational numbers are checked first so that plain ints become
        # Rational constants.
        if isinstance(value, numbers.Rational):
            return Rational(value)
        if not isinstance(value, Expression):
            raise TypeError('{} must be a rational number '
                'or a linear expression'.format(name))
        return value

    @functools.wraps(func)
    def wrapper(left, right):
        return func(convert(left, 'left'), convert(right, 'right'))
    return wrapper
302
@_polymorphic
def Lt(left, right):
    """Return the polyhedron where ``left < right``.

    The strict inequality is stored as ``right - left - 1 >= 0``, which is
    equivalent to ``left < right`` for integer values.
    """
    difference = right - left
    return Polyhedron(equalities=[], inequalities=[difference - 1])
306
@_polymorphic
def Le(left, right):
    """Return the polyhedron where ``left <= right``.

    Stored as the single inequality ``right - left >= 0``.
    """
    constraint = right - left
    return Polyhedron(inequalities=[constraint])
310
@_polymorphic
def Eq(left, right):
    """Return the polyhedron where ``left == right``.

    Stored as the single equality ``left - right == 0``.
    """
    return Polyhedron(equalities=[left - right], inequalities=[])
314
@_polymorphic
def Ne(left, right):
    """Return ``~Eq(left, right)``.

    The result is whatever ``__invert__`` produces on the equality
    polyhedron (presumably its complement, which need not be convex).
    """
    equal = Eq(left, right)
    return ~equal
318
@_polymorphic
def Gt(left, right):
    """Return the polyhedron where ``left > right``.

    The strict inequality is stored as ``left - right - 1 >= 0``, which is
    equivalent to ``left > right`` for integer values.
    """
    return Polyhedron(inequalities=[(left - right) - 1])
322
@_polymorphic
def Ge(left, right):
    """Return the polyhedron where ``left >= right``.

    Stored as the single inequality ``left - right >= 0``.
    """
    constraint = left - right
    return Polyhedron(equalities=[], inequalities=[constraint])
326
327
# The empty polyhedron: the single equality 1 == 0 is unsatisfiable.
Empty = Eq(1, 0)

# The universe: a polyhedron with no constraints at all.
Universe = Polyhedron()