Always set xlim, ylim, zlim in plot functions
[linpy.git] / pypol / polyhedra.py
1 import functools
2 import math
3 import numbers
4
5 from . import islhelper
6
7 from .islhelper import mainctx, libisl
8 from .geometry import GeometricObject, Point, Vector
9 from .linexprs import Expression, Symbol, Rational
10 from .domains import Domain
11
12
13 __all__ = [
14 'Polyhedron',
15 'Lt', 'Le', 'Eq', 'Ne', 'Ge', 'Gt',
16 'Empty', 'Universe',
17 ]
18
19
class Polyhedron(Domain):
    """
    A single polyhedron: a conjunction of linear constraints.

    Every expression ``e`` in :attr:`equalities` denotes ``e == 0`` and every
    expression ``i`` in :attr:`inequalities` denotes ``i >= 0``.  Instances
    are built by round-tripping the constraints through an isl basic set, so
    the stored constraints are whatever isl reports back.
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron.

        ``Polyhedron(string)`` parses ``string`` with :meth:`fromstring`, and
        ``Polyhedron(geometric_object)`` returns
        ``geometric_object.aspolyhedron()``.  Otherwise ``equalities`` and
        ``inequalities`` are iterables of linear expressions (None means
        none).  Each expression is replaced by ``expression.scaleint()``.
        The caller's sequences are never modified.

        Raises TypeError on a second argument in the string/geometric forms,
        or on a constraint that is not an Expression.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        elif isinstance(equalities, GeometricObject):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities.aspolyhedron()
        # Build fresh lists instead of assigning into the caller's sequences
        # (which also lets tuples and other iterables be passed).
        equalities = cls._normalized_constraints(
            equalities, 'equalities must be linear expressions')
        inequalities = cls._normalized_constraints(
            inequalities, 'inequalities must be linear expressions')
        symbols = cls._xsymbols(equalities + inequalities)
        islbset = cls._toislbasicset(equalities, inequalities, symbols)
        return cls._fromislbasicset(islbset, symbols)

    @staticmethod
    def _normalized_constraints(constraints, message):
        """
        Return a new list holding ``constraint.scaleint()`` for each item.

        ``constraints`` may be None (treated as empty).  Raises TypeError with
        ``message`` if an item is not an Expression.
        """
        if constraints is None:
            return []
        normalized = []
        for constraint in constraints:
            if not isinstance(constraint, Expression):
                raise TypeError(message)
            normalized.append(constraint.scaleint())
        return normalized

    @property
    def equalities(self):
        """Tuple of expressions ``e`` constrained by ``e == 0``."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions ``i`` constrained by ``i >= 0``."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities, as one tuple."""
        return self._constraints

    @property
    def polyhedra(self):
        """A one-element tuple containing this polyhedron."""
        return self,

    def disjoint(self):
        """
        Return this set as disjoint (a single polyhedron already is).
        """
        return self

    def isuniverse(self):
        """
        Return true if this set is the Universe set.

        The check is delegated to isl on a temporary basic set, which is
        freed before returning.
        """
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        universe = bool(libisl.isl_basic_set_is_universe(islbset))
        libisl.isl_basic_set_free(islbset)
        return universe

    def aspolyhedron(self):
        """
        Return the polyhedral hull of this set, i.e. the polyhedron itself.
        """
        return self

    def __contains__(self, point):
        """
        Return True if ``point`` satisfies every constraint.

        Raises TypeError if ``point`` is not a Point, and ValueError if its
        symbols differ from this polyhedron's symbols.
        """
        if not isinstance(point, Point):
            raise TypeError('point must be a Point instance')
        if self.symbols != point.symbols:
            raise ValueError('arguments must belong to the same space')
        for equality in self.equalities:
            if equality.subs(point.coordinates()) != 0:
                return False
        for inequality in self.inequalities:
            if inequality.subs(point.coordinates()) < 0:
                return False
        return True

    def subs(self, symbol, expression=None):
        """
        Return a new Polyhedron with ``Expression.subs(symbol, expression)``
        applied to every constraint.
        """
        equalities = [equality.subs(symbol, expression)
            for equality in self.equalities]
        inequalities = [inequality.subs(symbol, expression)
            for inequality in self.inequalities]
        return Polyhedron(equalities, inequalities)

    @classmethod
    def _fromislbasicset(cls, islbset, symbols):
        """
        Build a Polyhedron from the isl basic set ``islbset``, then free it.

        ``symbols`` names the set dimensions in order.  The resulting object's
        symbols are recomputed from the constraints isl returned, so symbols
        whose coefficients are all zero are dropped.
        """
        islconstraints = islhelper.isl_basic_set_constraints(islbset)
        equalities = []
        inequalities = []
        for islconstraint in islconstraints:
            constant = libisl.isl_constraint_get_constant_val(islconstraint)
            constant = islhelper.isl_val_to_int(constant)
            coefficients = {}
            for index, symbol in enumerate(symbols):
                coefficient = libisl.isl_constraint_get_coefficient_val(
                    islconstraint, libisl.isl_dim_set, index)
                coefficient = islhelper.isl_val_to_int(coefficient)
                if coefficient != 0:
                    coefficients[symbol] = coefficient
            expression = Expression(coefficients, constant)
            if libisl.isl_constraint_is_equality(islconstraint):
                equalities.append(expression)
            else:
                inequalities.append(expression)
        libisl.isl_basic_set_free(islbset)
        # Bypass Polyhedron.__new__, which would round-trip through isl again.
        self = object.__new__(Polyhedron)
        self._equalities = tuple(equalities)
        self._inequalities = tuple(inequalities)
        self._constraints = tuple(equalities + inequalities)
        self._symbols = cls._xsymbols(self._constraints)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislbasicset(cls, equalities, inequalities, symbols):
        """
        Return a new isl basic set for the given constraints.

        ``symbols`` fixes the order of the set dimensions.  Every symbol used
        by a constraint must appear in it (otherwise KeyError).  Callers in
        this class free the result, either directly or through
        :meth:`_fromislbasicset`.
        """
        dimension = len(symbols)
        indices = {symbol: index for index, symbol in enumerate(symbols)}
        islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension)
        islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp))
        islls = libisl.isl_local_space_from_space(islsp)
        # Equalities first, then inequalities.  Both are built the same way,
        # differing only in the isl allocator.
        groups = (
            (libisl.isl_equality_alloc, equalities),
            (libisl.isl_inequality_alloc, inequalities),
        )
        for allocate, expressions in groups:
            for expression in expressions:
                islconstraint = allocate(libisl.isl_local_space_copy(islls))
                for symbol, coefficient in expression.coefficients():
                    islval = str(coefficient).encode()
                    islval = libisl.isl_val_read_from_str(mainctx, islval)
                    index = indices[symbol]
                    islconstraint = libisl.isl_constraint_set_coefficient_val(
                        islconstraint, libisl.isl_dim_set, index, islval)
                if expression.constant != 0:
                    islval = str(expression.constant).encode()
                    islval = libisl.isl_val_read_from_str(mainctx, islval)
                    islconstraint = libisl.isl_constraint_set_constant_val(
                        islconstraint, islval)
                islbset = libisl.isl_basic_set_add_constraint(
                    islbset, islconstraint)
        # NOTE(review): islls itself is never released here.  Confirm that
        # libisl.isl_local_space_free(islls) is safe with islhelper's ctypes
        # bindings before adding the call.
        return islbset

    @classmethod
    def fromstring(cls, string):
        """
        Parse ``string`` with Domain.fromstring and require a Polyhedron.

        Raises ValueError if the parsed domain is not a single polyhedron.
        """
        domain = Domain.fromstring(string)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(string))
        return domain

    def __repr__(self):
        """
        Return 'Empty', 'Universe', or an Eq/Ge (or And(...)) expression.
        """
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            strings = []
            for equality in self.equalities:
                strings.append('Eq({}, 0)'.format(equality))
            for inequality in self.inequalities:
                strings.append('Ge({}, 0)'.format(inequality))
            if len(strings) == 1:
                return strings[0]
            else:
                return 'And({})'.format(', '.join(strings))

    def _repr_latex_(self):
        """
        Return a LaTeX conjunction of the constraints, for rich displays.
        """
        if self.isempty():
            return '$\\emptyset$'
        elif self.isuniverse():
            return '$\\Omega$'
        else:
            strings = []
            for equality in self.equalities:
                strings.append('{} = 0'.format(equality._repr_latex_().strip('$')))
            for inequality in self.inequalities:
                strings.append('{} \\ge 0'.format(inequality._repr_latex_().strip('$')))
            return '${}$'.format(' \\wedge '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression with Domain.fromsympy and require a
        Polyhedron.

        Raises ValueError if the result is not a single polyhedron.
        """
        domain = Domain.fromsympy(expr)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(expr))
        return domain

    def tosympy(self):
        """
        Return ``sympy.And`` of ``Eq(e, 0)`` for each equality and
        ``Ge(i, 0)`` for each inequality.
        """
        import sympy
        constraints = []
        for equality in self.equalities:
            constraints.append(sympy.Eq(equality.tosympy(), 0))
        for inequality in self.inequalities:
            constraints.append(sympy.Ge(inequality.tosympy(), 0))
        return sympy.And(*constraints)

    @classmethod
    def _polygon_inner_point(cls, points):
        """
        Return the centroid (coordinate-wise mean) of a non-empty list of
        points.
        """
        symbols = points[0].symbols
        coordinates = {symbol: 0 for symbol in symbols}
        for point in points:
            for symbol, coordinate in point.coordinates():
                coordinates[symbol] += coordinate
        for symbol in symbols:
            coordinates[symbol] /= len(points)
        return Point(coordinates)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """
        Order 2D ``points`` by their angle (atan2) around the centroid.

        Lists of at most three points are returned unchanged.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angle = math.atan2(dy, dx)
            angles[m] = angle
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """
        Order the points of a 3D polygon by angle around their centroid.

        In _plot_3d the points are the vertices of one face.  The first point
        is the angular origin.  Each angle's sign comes from the orientation
        of ``oa x om`` against a fixed unit normal ``u``.  Lists of at most
        three points are returned unchanged.

        Raises ValueError if no point gives a non-null cross product with
        the first one (degenerate polygon).
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        a = points[0]
        oa = Vector(o, a)
        norm_oa = oa.norm()
        for b in points[1:]:
            ob = Vector(o, b)
            u = oa.cross(ob)
            if not u.isnull():
                u = u.asunit()
                break
        else:
            raise ValueError('degenerate polygon')
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            # Rounding can push the cosine slightly outside [-1, 1], and
            # math.acos would then raise; clamp on both sides.
            cosinus = min(max(oa.dot(om) / normprod, -1.), 1.)
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def faces(self):
        """
        Return one list per constraint, holding the vertices on which that
        constraint evaluates to exactly 0.
        """
        vertices = self.vertices()
        faces = []
        for constraint in self.constraints:
            face = []
            for vertex in vertices:
                if constraint.subs(vertex.coordinates()) == 0:
                    face.append(vertex)
            faces.append(face)
        return faces

    def _plot_2d(self, plot=None, **kwargs):
        """
        Draw this 2D polyhedron as a closed matplotlib Polygon.

        Draws into the axes ``plot`` (a new figure and axes when None) and
        widens the x and y limits so every vertex is visible.  Keyword
        arguments go to ``Polygon``.  Returns the axes.
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Polygon
        vertices = self._sort_polygon_2d(self.vertices())
        xys = [tuple(vertex.values()) for vertex in vertices]
        if plot is None:
            fig = plt.figure()
            plot = fig.add_subplot(1, 1, 1)
        xmin, xmax = plot.get_xlim()
        ymin, ymax = plot.get_ylim()
        xs, ys = zip(*xys)
        xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
        ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
        plot.set_xlim(xmin, xmax)
        plot.set_ylim(ymin, ymax)
        plot.add_patch(Polygon(xys, closed=True, **kwargs))
        return plot

    def _plot_3d(self, plot=None, **kwargs):
        """
        Draw this 3D polyhedron as a matplotlib Poly3DCollection of faces.

        Draws into ``plot``, presumably a 3D axes (a new figure and 3D axes
        when None).  Widens the x, y and z limits so every vertex is visible.
        Keyword arguments go to ``Poly3DCollection``.  Returns the axes.
        """
        import matplotlib.pyplot as plt
        # Importing Axes3D registers the '3d' projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
        if plot is None:
            fig = plt.figure()
            # Recent matplotlib no longer attaches Axes3D(fig) to the figure.
            # add_subplot(projection='3d') works on old and new releases.
            axes = fig.add_subplot(1, 1, 1, projection='3d')
        else:
            axes = plot
        xmin, xmax = axes.get_xlim()
        ymin, ymax = axes.get_ylim()
        zmin, zmax = axes.get_zlim()
        poly_xyzs = []
        for vertices in self.faces():
            if len(vertices) == 0:
                continue
            vertices = Polyhedron._sort_polygon_3d(vertices)
            # Close the polygon without mutating the list we were handed.
            vertices = list(vertices) + [vertices[0]]
            face_xyzs = [tuple(vertex.values()) for vertex in vertices]
            xs, ys, zs = zip(*face_xyzs)
            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
            zmin, zmax = min(zmin, float(min(zs))), max(zmax, float(max(zs)))
            poly_xyzs.append(face_xyzs)
        collection = Poly3DCollection(poly_xyzs, **kwargs)
        axes.add_collection3d(collection)
        axes.set_xlim(xmin, xmax)
        axes.set_ylim(ymin, ymax)
        axes.set_zlim(zmin, zmax)
        return axes

    def plot(self, plot=None, **kwargs):
        """
        Plot this polyhedron with matplotlib and return the axes used.

        A 2-dimensional polyhedron is drawn as a filled polygon, and a
        3-dimensional one as a collection of faces.  ``plot`` is an existing
        axes to draw into (a new figure is created when None).  Extra keyword
        arguments are forwarded to the matplotlib patch or collection.

        Raises ValueError for any other dimension.
        """
        if self.dimension == 2:
            return self._plot_2d(plot=plot, **kwargs)
        elif self.dimension == 3:
            return self._plot_3d(plot=plot, **kwargs)
        else:
            raise ValueError('polyhedron must be 2 or 3-dimensional')
352
353
def _coerce_operand(value, name):
    """
    Return ``value`` as a linear expression usable as a comparison operand.

    Expressions are returned unchanged.  Rational numbers (anything
    registered as ``numbers.Rational``, e.g. int or Fraction) are wrapped in
    ``Rational``.  ``name`` ('left' or 'right') appears in the error message.

    Raises TypeError for any other value.
    """
    if isinstance(value, Expression):
        return value
    if isinstance(value, numbers.Rational):
        return Rational(value)
    raise TypeError('{} must be a rational number or a linear expression'
                    .format(name))


def _polymorphic(func):
    """
    Decorate a two-operand comparison so it accepts numbers or expressions.

    The left operand is coerced before the right one, so the left operand's
    error is the one reported when both are invalid.  ``func`` therefore
    always receives linear expressions.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        left = _coerce_operand(left, 'left')
        right = _coerce_operand(right, 'right')
        return func(left, right)
    return wrapper
371
@_polymorphic
def Lt(left, right):
    """
    Return the polyhedron of points satisfying ``left < right``.

    The strict inequality is stored as the single constraint
    ``right - left - 1 >= 0``, which matches ``left < right`` for integer
    values.
    """
    slack = right - left - 1
    return Polyhedron(equalities=[], inequalities=[slack])
378
@_polymorphic
def Le(left, right):
    """
    Return the polyhedron of points satisfying ``left <= right``.

    Stored as the single constraint ``right - left >= 0``.
    """
    slack = right - left
    return Polyhedron(equalities=[], inequalities=[slack])
385
@_polymorphic
def Eq(left, right):
    """
    Return the polyhedron of points satisfying ``left == right``.

    Stored as the single equality ``left - right == 0``.
    """
    difference = left - right
    return Polyhedron(equalities=[difference], inequalities=[])
392
@_polymorphic
def Ne(left, right):
    """
    Return the set of points where ``left != right``.

    Computed as the complement (``~``) of ``Eq(left, right)``.
    """
    equal = Eq(left, right)
    return ~equal
399
@_polymorphic
def Gt(left, right):
    """
    Return the polyhedron of points satisfying ``left > right``.

    The strict inequality is stored as the single constraint
    ``left - right - 1 >= 0``, which matches ``left > right`` for integer
    values.
    """
    slack = left - right - 1
    return Polyhedron(equalities=[], inequalities=[slack])
406
@_polymorphic
def Ge(left, right):
    """
    Return the polyhedron of points satisfying ``left >= right``.

    Stored as the single constraint ``left - right >= 0``.
    """
    slack = left - right
    return Polyhedron(equalities=[], inequalities=[slack])
413
414
# The empty polyhedron: Eq(1, 0) requires 1 - 0 == 0, which no point satisfies.
Empty = Eq(left=1, right=0)

# The universe: a polyhedron with no constraints at all.
Universe = Polyhedron(equalities=[])