Small implementation improvement in linexprs.py
[linpy.git] / pypol / polyhedra.py
1 import functools
2 import math
3 import numbers
4
5 from . import islhelper
6
7 from .islhelper import mainctx, libisl
8 from .geometry import GeometricObject, Point, Vector
9 from .linexprs import Expression, Symbol, Rational
10 from .domains import Domain
11
12
# Names exported by ``from ... import *``: the Polyhedron class, the
# comparison constructors and the two constant polyhedra.
__all__ = [
    'Polyhedron',
    'Lt',
    'Le',
    'Eq',
    'Ne',
    'Ge',
    'Gt',
    'Empty',
    'Universe',
]
18
19
class Polyhedron(Domain):
    """
    A polyhedron described by linear equalities and inequalities.

    Each equality ``e`` stands for the constraint ``e == 0`` and each
    inequality ``i`` for ``i >= 0``. On construction the constraints are
    round-tripped through an isl basic set (see _toislbasicset() and
    _fromislbasicset()), so the stored constraints are the ones isl returns
    rather than necessarily the ones that were passed in.
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron.

        Polyhedron(equalities=None, inequalities=None) takes iterables of
        linear expressions (None meaning "no constraint of that kind"); each
        expression is normalized with Expression.scaleint(). The arguments
        are only read, never modified.

        Polyhedron(string) parses ``string`` with fromstring(), and
        Polyhedron(obj) for a GeometricObject returns obj.aspolyhedron().

        Raises TypeError if a constraint is not an Expression, or if a second
        argument is given together with a string or a GeometricObject.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        elif isinstance(equalities, GeometricObject):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities.aspolyhedron()
        # Build fresh lists instead of overwriting the caller's items in
        # place (which also failed outright for tuples and other immutable
        # iterables).
        equalities = cls._scaleconstraints(equalities,
            'equalities must be linear expressions')
        inequalities = cls._scaleconstraints(inequalities,
            'inequalities must be linear expressions')
        symbols = cls._xsymbols(equalities + inequalities)
        islbset = cls._toislbasicset(equalities, inequalities, symbols)
        return cls._fromislbasicset(islbset, symbols)

    @staticmethod
    def _scaleconstraints(constraints, message):
        # Return a new list holding constraint.scaleint() for every
        # constraint (None yields an empty list); raise TypeError(message)
        # on an item that is not an Expression.
        if constraints is None:
            return []
        scaled = []
        for constraint in constraints:
            if not isinstance(constraint, Expression):
                raise TypeError(message)
            scaled.append(constraint.scaleint())
        return scaled

    @property
    def equalities(self):
        """
        Tuple of the expressions ``e`` constrained by ``e == 0``.
        """
        return self._equalities

    @property
    def inequalities(self):
        """
        Tuple of the expressions ``i`` constrained by ``i >= 0``.
        """
        return self._inequalities

    @property
    def constraints(self):
        """
        Tuple of all constraints: the equalities followed by the inequalities.
        """
        return self._constraints

    @property
    def polyhedra(self):
        """
        One-element tuple containing this polyhedron.
        """
        return self,

    def disjoint(self):
        """
        Return this set as disjoint; a single polyhedron already is, so
        return self.
        """
        return self

    def isuniverse(self):
        """
        Return True if isl reports the basic set built from this polyhedron's
        constraints as the universe set.
        """
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        try:
            return bool(libisl.isl_basic_set_is_universe(islbset))
        finally:
            libisl.isl_basic_set_free(islbset)

    def aspolyhedron(self):
        """
        Return the polyhedral hull of this set, which is the polyhedron
        itself.
        """
        return self

    def __contains__(self, point):
        """
        Return True if ``point`` makes every equality evaluate to 0 and every
        inequality evaluate to a value >= 0.

        Raises TypeError if point is not a Point, and ValueError if its
        symbols differ from those of the polyhedron.
        """
        if not isinstance(point, Point):
            raise TypeError('point must be a Point instance')
        if self.symbols != point.symbols:
            raise ValueError('arguments must belong to the same space')
        for equality in self.equalities:
            if equality.subs(point.coordinates()) != 0:
                return False
        for inequality in self.inequalities:
            if inequality.subs(point.coordinates()) < 0:
                return False
        return True

    def subs(self, symbol, expression=None):
        """
        Return a new Polyhedron whose constraints are those of this one with
        Expression.subs(symbol, expression) applied to each.
        """
        equalities = [equality.subs(symbol, expression)
            for equality in self.equalities]
        inequalities = [inequality.subs(symbol, expression)
            for inequality in self.inequalities]
        return Polyhedron(equalities, inequalities)

    @classmethod
    def _fromislbasicset(cls, islbset, symbols):
        # Build a Polyhedron from an isl basic set and free ``islbset``.
        # ``symbols[index]`` is the symbol of isl set dimension ``index``.
        islconstraints = islhelper.isl_basic_set_constraints(islbset)
        equalities = []
        inequalities = []
        for islconstraint in islconstraints:
            constant = libisl.isl_constraint_get_constant_val(islconstraint)
            constant = islhelper.isl_val_to_int(constant)
            coefficients = {}
            for index, symbol in enumerate(symbols):
                coefficient = libisl.isl_constraint_get_coefficient_val(
                    islconstraint, libisl.isl_dim_set, index)
                coefficient = islhelper.isl_val_to_int(coefficient)
                if coefficient != 0:
                    coefficients[symbol] = coefficient
            expression = Expression(coefficients, constant)
            if libisl.isl_constraint_is_equality(islconstraint):
                equalities.append(expression)
            else:
                inequalities.append(expression)
        libisl.isl_basic_set_free(islbset)
        # Allocate the instance directly, bypassing __new__ and its isl round
        # trip: the constraints have just come out of isl.
        self = object.__new__(Polyhedron)
        self._equalities = tuple(equalities)
        self._inequalities = tuple(inequalities)
        self._constraints = tuple(equalities + inequalities)
        # Symbols are recomputed from the returned constraints via
        # _xsymbols() rather than taken from ``symbols``.
        self._symbols = cls._xsymbols(self._constraints)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislbasicset(cls, equalities, inequalities, symbols):
        # Build a new isl basic set with one set dimension per symbol and one
        # isl constraint per expression; the caller must free the result.
        indices = {symbol: index for index, symbol in enumerate(symbols)}
        islsp = libisl.isl_space_set_alloc(mainctx, 0, len(symbols))
        islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp))
        islls = libisl.isl_local_space_from_space(islsp)
        for equality in equalities:
            islcons = libisl.isl_equality_alloc(
                libisl.isl_local_space_copy(islls))
            islcons = cls._setislconstraint(islcons, equality, indices)
            islbset = libisl.isl_basic_set_add_constraint(islbset, islcons)
        for inequality in inequalities:
            islcons = libisl.isl_inequality_alloc(
                libisl.isl_local_space_copy(islls))
            islcons = cls._setislconstraint(islcons, inequality, indices)
            islbset = libisl.isl_basic_set_add_constraint(islbset, islcons)
        # Each constraint was given its own copy of the local space; release
        # the original, which was previously leaked on every call.
        libisl.isl_local_space_free(islls)
        return islbset

    @staticmethod
    def _setislconstraint(islcons, expression, indices):
        # Copy the coefficients and the (non-zero) constant of ``expression``
        # into the isl constraint ``islcons`` and return the updated
        # constraint.
        for symbol, coefficient in expression.coefficients():
            islval = libisl.isl_val_read_from_str(mainctx,
                str(coefficient).encode())
            islcons = libisl.isl_constraint_set_coefficient_val(islcons,
                libisl.isl_dim_set, indices[symbol], islval)
        if expression.constant != 0:
            islval = libisl.isl_val_read_from_str(mainctx,
                str(expression.constant).encode())
            islcons = libisl.isl_constraint_set_constant_val(islcons, islval)
        return islcons

    @classmethod
    def fromstring(cls, string):
        """
        Parse ``string`` with Domain.fromstring().

        Raises ValueError unless the result is a single Polyhedron.
        """
        domain = Domain.fromstring(string)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(string))
        return domain

    def __repr__(self):
        """
        Return 'Empty', 'Universe', a single 'Eq(e, 0)' / 'Ge(i, 0)' string,
        or an 'And(...)' of them.
        """
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            strings = []
            for equality in self.equalities:
                strings.append('Eq({}, 0)'.format(equality))
            for inequality in self.inequalities:
                strings.append('Ge({}, 0)'.format(inequality))
            if len(strings) == 1:
                return strings[0]
            else:
                return 'And({})'.format(', '.join(strings))

    def _repr_latex_(self):
        """
        Return a LaTeX rendering for IPython: the empty-set symbol, Omega for
        the universe, or the constraints joined by a logical "and".
        """
        if self.isempty():
            return '$\\emptyset$'
        elif self.isuniverse():
            return '$\\Omega$'
        else:
            strings = []
            for equality in self.equalities:
                strings.append('{} = 0'.format(equality._repr_latex_().strip('$')))
            for inequality in self.inequalities:
                strings.append('{} \\ge 0'.format(inequality._repr_latex_().strip('$')))
            return '${}$'.format(' \\wedge '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression with Domain.fromsympy().

        Raises ValueError unless the result is a single Polyhedron.
        """
        domain = Domain.fromsympy(expr)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(expr))
        return domain

    def tosympy(self):
        """
        Return a sympy And() of Eq(e, 0) for each equality and Ge(i, 0) for
        each inequality.
        """
        import sympy
        constraints = []
        for equality in self.equalities:
            constraints.append(sympy.Eq(equality.tosympy(), 0))
        for inequality in self.inequalities:
            constraints.append(sympy.Ge(inequality.tosympy(), 0))
        return sympy.And(*constraints)

    @classmethod
    def _polygon_inner_point(cls, points):
        # Return the centroid (coordinate-wise mean) of ``points``, which
        # must be non-empty and share the symbols of points[0].
        symbols = points[0].symbols
        coordinates = {symbol: 0 for symbol in symbols}
        for point in points:
            for symbol, coordinate in point.coordinates():
                coordinates[symbol] += coordinate
        for symbol in symbols:
            coordinates[symbol] /= len(points)
        return Point(coordinates)

    @classmethod
    def _sort_polygon_2d(cls, points):
        # Order 2D polygon vertices by their angle around the centroid;
        # three or fewer points are returned unchanged.
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angle = math.atan2(dy, dx)
            angles[m] = angle
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        # Order the vertices of a planar polygon in 3D by signed angle around
        # the centroid, measured from points[0] with the plane normal ``u``
        # giving the sign. Three or fewer points are returned unchanged.
        # Raises ValueError if no normal can be found (degenerate polygon).
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        a = points[0]
        oa = Vector(o, a)
        norm_oa = oa.norm()
        for b in points[1:]:
            ob = Vector(o, b)
            u = oa.cross(ob)
            if not u.isnull():
                u = u.asunit()
                break
        else:
            raise ValueError('degenerate polygon')
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            # Clamp into [-1, 1] on both sides: round-off can push the ratio
            # slightly past either bound, and math.acos rejects such values.
            cosinus = min(max(oa.dot(om) / normprod, -1.), 1.)
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def faces(self):
        """
        Return one list per constraint (in the order of self.constraints)
        holding the vertices at which that constraint evaluates to 0.
        """
        vertices = self.vertices()
        faces = []
        for constraint in self.constraints:
            face = []
            for vertex in vertices:
                if constraint.subs(vertex.coordinates()) == 0:
                    face.append(vertex)
            faces.append(face)
        return faces

    def plot(self):
        """
        Plot this polyhedron with matplotlib and show the figure.

        A polyhedron with 2 symbols is drawn as a filled polygon through its
        vertices; one with 3 symbols as translucent faces (see faces()), with
        the axis limits fitted to the face vertices.

        Returns the list of float coordinate tuples last handed to
        matplotlib: all vertices in 2D, the vertices of the last face in 3D
        (an empty list if there are no faces).

        Raises TypeError if the polyhedron has neither 2 nor 3 symbols.
        """
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches

        def tofloats(vertex):
            # Coordinates of a Point as a tuple of floats, in the order
            # vertex.coordinates() yields them.
            return tuple(float(coordinate)
                for symbol, coordinate in vertex.coordinates())

        dimension = len(self.symbols)
        if dimension not in (2, 3):
            # Dimensions 0 and 1 used to fall through to ``return points``
            # and fail with UnboundLocalError.
            raise TypeError('can only plot polyhedra of dimension 2 or 3')

        if dimension == 2:
            points = [tofloats(vertex) for vertex in self.vertices()]
            # Order the vertices by angle around their centroid so that the
            # polygon outline does not cross itself.
            cx = sum(point[0] for point in points) / len(points)
            cy = sum(point[1] for point in points) / len(points)
            points.sort(key=lambda point: math.atan2(point[1] - cy,
                point[0] - cx))
            plt.scatter([point[0] for point in points],
                [point[1] for point in points])
            plt.gca().add_patch(patches.Polygon(points, closed=True,
                fill=True))
            plt.grid()
            plt.show()
            return points

        # Importing Axes3D registers the '3d' projection on older matplotlib
        # releases; the name itself is not used.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
        fig = plt.figure()
        # Axes3D(fig) is no longer attached to the figure by recent
        # matplotlib versions (the window stays empty); add_subplot is.
        ax = fig.add_subplot(111, projection='3d')
        points = []
        allpoints = []
        for face in self.faces():
            points = [tofloats(vertex)
                for vertex in Polyhedron._sort_polygon_3d(face)]
            allpoints.extend(points)
            collection = Poly3DCollection([points], alpha=0.7)
            collection.set_facecolor([0.5, 0.5, 1])
            ax.add_collection3d(collection)
        axes = (
            (ax.set_xlabel, ax.set_xlim, 'X'),
            (ax.set_ylabel, ax.set_ylim, 'Y'),
            (ax.set_zlabel, ax.set_zlim, 'Z'),
        )
        for index, (setlabel, setlim, label) in enumerate(axes):
            setlabel(label)
            if allpoints:
                # Fit the view to the drawn vertices instead of a fixed
                # 0..5 cube, which hid polyhedra lying elsewhere.
                setlim(min(point[index] for point in allpoints),
                    max(point[index] for point in allpoints))
            else:
                setlim(0, 5)
        plt.grid()
        plt.show()
        return points

    @classmethod
    def limit(cls, faces, variable, lim):
        """
        Return the smallest or largest value of one coordinate over all the
        vertices of ``faces``.

        faces: iterable of faces, each an iterable of Point (as returned by
            faces()).
        variable: 'x', 'y' or 'z', selecting the first, second or third
            coordinate yielded by Point.coordinates().
        lim: 0 for the minimum; any other value for the maximum.

        Raises ValueError if ``variable`` is not 'x', 'y' or 'z', or if
        ``faces`` contains no vertex.
        """
        # The previous implementation compared strings with ``is``, left the
        # axis index undefined for other names and appended an undefined
        # name, so it could never return a value.
        axes = {'x': 0, 'y': 1, 'z': 2}
        if variable not in axes:
            raise ValueError(
                'variable must be one of {!r}'.format(sorted(axes)))
        index = axes[variable]
        values = []
        for face in faces:
            for vertex in face:
                coordinates = [coordinate
                    for symbol, coordinate in vertex.coordinates()]
                values.append(coordinates[index])
        return min(values) if lim == 0 else max(values)
366
def _polymorphic(func):
    """
    Decorator for the two-argument comparison constructors below.

    Before calling ``func(left, right)``, each argument is coerced to a
    linear expression: Expression instances are passed through unchanged,
    numbers.Rational values (int, fractions.Fraction, ...) are wrapped in
    Rational, and anything else raises TypeError. ``left`` is checked before
    ``right``.
    """
    def _toexpression(value, name):
        # Coerce one operand; ``name`` only appears in the error message.
        if isinstance(value, Expression):
            return value
        if isinstance(value, numbers.Rational):
            return Rational(value)
        # The message previously read "must be a a rational number".
        raise TypeError('{} must be a rational number '
            'or a linear expression'.format(name))

    @functools.wraps(func)
    def wrapper(left, right):
        left = _toexpression(left, 'left')
        right = _toexpression(right, 'right')
        return func(left, right)
    return wrapper
384
@_polymorphic
def Lt(left, right):
    """
    Return the Polyhedron of points where ``left < right``.

    ``left`` and ``right`` are linear expressions or rational numbers. The
    strict inequality is encoded as ``d - 1 >= 0``, where ``d`` is
    ``right - left`` scaled with scaleint(); the ``- 1`` step assumes
    integer-valued points. Scaling first keeps the encoding exact for
    rational operands: ``x < 1/2`` becomes ``1 - 2*x - 1 >= 0`` (x <= 0)
    rather than ``1/2 - x - 1 >= 0`` (x <= -1).
    """
    return Polyhedron([], [(right - left).scaleint() - 1])
391
@_polymorphic
def Le(left, right):
    """
    Return the Polyhedron of points where ``left <= right``.

    ``left`` and ``right`` are linear expressions or rational numbers; the
    constraint is stored as the inequality ``right - left >= 0``.
    """
    return Polyhedron([], [right - left])
398
@_polymorphic
def Eq(left, right):
    """
    Return the Polyhedron of points where ``left == right``.

    ``left`` and ``right`` are linear expressions or rational numbers; the
    constraint is stored as the equality ``left - right == 0``.
    """
    return Polyhedron([left - right], [])
405
@_polymorphic
def Ne(left, right):
    """
    Return the set of points where ``left != right``.

    ``left`` and ``right`` are linear expressions or rational numbers. The
    result is the complement (``~``) of Eq(left, right).
    """
    return ~Eq(left, right)
412
@_polymorphic
def Gt(left, right):
    """
    Return the Polyhedron of points where ``left > right``.

    ``left`` and ``right`` are linear expressions or rational numbers. The
    strict inequality is encoded as ``d - 1 >= 0``, where ``d`` is
    ``left - right`` scaled with scaleint(); the ``- 1`` step assumes
    integer-valued points. Scaling first keeps the encoding exact for
    rational operands: ``x > 1/2`` becomes ``2*x - 1 - 1 >= 0`` (x >= 1)
    rather than ``x - 1/2 - 1 >= 0`` (x >= 2).
    """
    return Polyhedron([], [(left - right).scaleint() - 1])
419
@_polymorphic
def Ge(left, right):
    """
    Return the Polyhedron of points where ``left >= right``.

    ``left`` and ``right`` are linear expressions or rational numbers; the
    constraint is stored as the inequality ``left - right >= 0``.
    """
    return Polyhedron([], [left - right])
426
427
# The empty polyhedron: the single unsatisfiable equality 1 == 0.
Empty = Eq(left=1, right=0)

# The universe: a polyhedron with no constraints at all.
Universe = Polyhedron(equalities=[])