Simplify Expression.__mul__(), Expression.__truediv__()
[linpy.git] / pypol / polyhedra.py
1 import functools
2 import math
3 import numbers
4
5 from . import islhelper
6
7 from .islhelper import mainctx, libisl
8 from .geometry import GeometricObject, Point, Vector
9 from .linexprs import Expression, Symbol, Rational
10 from .domains import Domain
11
12
# Names exported by a star-import of this module: the polyhedron class, the
# comparison-style constructors and the two predefined polyhedra.
__all__ = [
    'Polyhedron',
    'Lt',
    'Le',
    'Eq',
    'Ne',
    'Ge',
    'Gt',
    'Empty',
    'Universe',
]
18
19
class Polyhedron(Domain):
    """
    Domain described by a conjunction of linear constraints.

    The constraints are kept as two tuples of linear expressions: the
    equalities (each expression is constrained to be == 0) and the
    inequalities (each expression is constrained to be >= 0).
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron.

        The first argument may be:
        - a string, which is parsed with fromstring();
        - a GeometricObject, which is converted with aspolyhedron();
        - an iterable of Expression instances (or None for no equality).

        In the first two cases, passing inequalities raises TypeError.
        Otherwise inequalities is an iterable of Expression instances (or
        None). Every expression goes through Expression.scaleint() before
        being handed to isl. The caller's containers are left untouched.

        Raises TypeError if a constraint is not an Expression.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        elif isinstance(equalities, GeometricObject):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities.aspolyhedron()
        # Fresh lists are built instead of overwriting the arguments in
        # place, so callers keep their data and may pass tuples as well.
        equalities = cls._scaledexpressions(equalities,
            'equalities must be linear expressions')
        inequalities = cls._scaledexpressions(inequalities,
            'inequalities must be linear expressions')
        symbols = cls._xsymbols(equalities + inequalities)
        islbset = cls._toislbasicset(equalities, inequalities, symbols)
        return cls._fromislbasicset(islbset, symbols)

    @staticmethod
    def _scaledexpressions(expressions, message):
        """
        Return a new list holding expression.scaleint() for each expression.

        None is treated as an empty collection. TypeError(message) is raised
        on the first item that is not an Expression.
        """
        scaled = []
        if expressions is not None:
            for expression in expressions:
                if not isinstance(expression, Expression):
                    raise TypeError(message)
                scaled.append(expression.scaleint())
        return scaled

    @property
    def equalities(self):
        """Tuple of expressions constrained to be equal to zero."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions constrained to be greater or equal to zero."""
        return self._inequalities

    @property
    def constraints(self):
        """Tuple of all constraints: the equalities, then the inequalities."""
        return self._constraints

    @property
    def polyhedra(self):
        """One-element tuple containing this polyhedron."""
        return self,

    def disjoint(self):
        """Return this polyhedron itself."""
        return self

    def isuniverse(self):
        """Return True if isl reports this polyhedron as the universe set."""
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        try:
            return bool(libisl.isl_basic_set_is_universe(islbset))
        finally:
            # Release the temporary isl object even if the query fails.
            libisl.isl_basic_set_free(islbset)

    def aspolyhedron(self):
        """Return this polyhedron itself."""
        return self

    def __contains__(self, point):
        """
        Return True if the point satisfies every constraint.

        Raises TypeError if point is not a Point, and ValueError if the point
        and the polyhedron do not have the same symbols.
        """
        if not isinstance(point, Point):
            raise TypeError('point must be a Point instance')
        if self.symbols != point.symbols:
            raise ValueError('arguments must belong to the same space')
        # point.coordinates() is called anew for each constraint on purpose:
        # its result may be a one-shot iterator.
        for equality in self.equalities:
            if equality.subs(point.coordinates()) != 0:
                return False
        for inequality in self.inequalities:
            if inequality.subs(point.coordinates()) < 0:
                return False
        return True

    def subs(self, symbol, expression=None):
        """
        Return the polyhedron obtained by applying
        Expression.subs(symbol, expression) to every constraint.
        """
        equalities = [equality.subs(symbol, expression)
            for equality in self.equalities]
        inequalities = [inequality.subs(symbol, expression)
            for inequality in self.inequalities]
        return Polyhedron(equalities, inequalities)

    @classmethod
    def _fromislbasicset(cls, islbset, symbols):
        """
        Build a Polyhedron from an isl basic set.

        symbols gives, by position, the symbol attached to each set dimension
        of islbset. The basic set is freed before returning, so the caller
        must not use it afterwards.
        """
        islconstraints = islhelper.isl_basic_set_constraints(islbset)
        equalities = []
        inequalities = []
        for islconstraint in islconstraints:
            constant = libisl.isl_constraint_get_constant_val(islconstraint)
            constant = islhelper.isl_val_to_int(constant)
            coefficients = {}
            for index, symbol in enumerate(symbols):
                coefficient = libisl.isl_constraint_get_coefficient_val(
                    islconstraint, libisl.isl_dim_set, index)
                coefficient = islhelper.isl_val_to_int(coefficient)
                # Only non-zero coefficients are recorded.
                if coefficient != 0:
                    coefficients[symbol] = coefficient
            expression = Expression(coefficients, constant)
            if libisl.isl_constraint_is_equality(islconstraint):
                equalities.append(expression)
            else:
                inequalities.append(expression)
        libisl.isl_basic_set_free(islbset)
        # Bypass Polyhedron.__new__, which would go through isl again.
        self = object.__new__(Polyhedron)
        self._equalities = tuple(equalities)
        self._inequalities = tuple(inequalities)
        self._constraints = tuple(equalities + inequalities)
        # The symbols are recomputed from the constraints actually returned
        # by isl rather than taken from the symbols argument.
        self._symbols = cls._xsymbols(self._constraints)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislbasicset(cls, equalities, inequalities, symbols):
        """
        Translate constraints into a new isl basic set.

        equalities and inequalities are iterables of expressions; symbols
        fixes the set dimension index of each symbol. The caller owns the
        returned basic set.
        """
        dimension = len(symbols)
        indices = {symbol: index for index, symbol in enumerate(symbols)}
        islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension)
        islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp))
        islls = libisl.isl_local_space_from_space(islsp)
        # NOTE(review): islls itself is never released here (only copies of
        # it are handed to the constraint allocators); this looks like a
        # leak that isl_local_space_free would fix -- confirm against the
        # isl manual and the libisl bindings.
        # Equalities and inequalities differ only by the allocator used.
        groups = (
            (equalities, libisl.isl_equality_alloc),
            (inequalities, libisl.isl_inequality_alloc),
        )
        for expressions, allocate in groups:
            for expression in expressions:
                islconstraint = allocate(libisl.isl_local_space_copy(islls))
                for symbol, coefficient in expression.coefficients():
                    # Values are passed to isl through their decimal string.
                    islval = str(coefficient).encode()
                    islval = libisl.isl_val_read_from_str(mainctx, islval)
                    index = indices[symbol]
                    islconstraint = libisl.isl_constraint_set_coefficient_val(
                        islconstraint, libisl.isl_dim_set, index, islval)
                if expression.constant != 0:
                    islval = str(expression.constant).encode()
                    islval = libisl.isl_val_read_from_str(mainctx, islval)
                    islconstraint = libisl.isl_constraint_set_constant_val(
                        islconstraint, islval)
                islbset = libisl.isl_basic_set_add_constraint(islbset,
                    islconstraint)
        return islbset

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string with Domain.fromstring() and return the result.

        Raises ValueError if the parsed domain is not a Polyhedron.
        """
        domain = Domain.fromstring(string)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(string))
        return domain

    def __repr__(self):
        """
        Return 'Empty', 'Universe', a single Eq(...)/Ge(...) term, or an
        And(...) of such terms.
        """
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        strings = ['Eq({}, 0)'.format(equality)
            for equality in self.equalities]
        strings.extend('Ge({}, 0)'.format(inequality)
            for inequality in self.inequalities)
        if len(strings) == 1:
            return strings[0]
        return 'And({})'.format(', '.join(strings))

    def _repr_latex_(self):
        """Return a LaTeX representation, delimited by dollar signs."""
        if self.isempty():
            return '$\\emptyset$'
        elif self.isuniverse():
            return '$\\Omega$'
        # The '$' delimiters of each expression are stripped so that the
        # whole conjunction sits in a single math environment.
        strings = ['{} = 0'.format(equality._repr_latex_().strip('$'))
            for equality in self.equalities]
        strings.extend('{} \\ge 0'.format(inequality._repr_latex_().strip('$'))
            for inequality in self.inequalities)
        return '${}$'.format(' \\wedge '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression with Domain.fromsympy() and return it.

        Raises ValueError if the converted domain is not a Polyhedron.
        """
        domain = Domain.fromsympy(expr)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(expr))
        return domain

    def tosympy(self):
        """Return a sympy.And of Eq(..., 0) and Ge(..., 0) relations."""
        import sympy
        constraints = [sympy.Eq(equality.tosympy(), 0)
            for equality in self.equalities]
        constraints.extend(sympy.Ge(inequality.tosympy(), 0)
            for inequality in self.inequalities)
        return sympy.And(*constraints)

    @classmethod
    def _polygon_inner_point(cls, points):
        """
        Return the Point whose coordinates are the mean of those of points.

        points must be a non-empty sequence; the symbols are taken from its
        first element.
        """
        symbols = points[0].symbols
        coordinates = {symbol: 0 for symbol in symbols}
        for point in points:
            for symbol, coordinate in point.coordinates():
                coordinates[symbol] += coordinate
        for symbol in symbols:
            coordinates[symbol] /= len(points)
        return Point(coordinates)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """
        Return the points sorted by angle around their mean point.

        With three points or fewer, points is returned unchanged. Each vector
        from the mean point is expected to have exactly two coordinates.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angles[m] = math.atan2(dy, dx)
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """
        Return the points sorted by signed angle around their mean point.

        Angles are measured from the first point, their sign being given by
        a unit normal computed from the first non-null cross product. With
        three points or fewer, points is returned unchanged.

        Raises ValueError('degenerate polygon') if every cross product with
        the first direction is null.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        a = points[0]
        oa = Vector(o, a)
        norm_oa = oa.norm()
        for b in points[1:]:
            ob = Vector(o, b)
            u = oa.cross(ob)
            if not u.isnull():
                u = u.asunit()
                break
        else:
            raise ValueError('degenerate polygon')
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            cosinus = oa.dot(om) / normprod
            sinus = u.dot(oa.cross(om)) / normprod
            # Rounding may push the cosine marginally outside [-1, 1], which
            # would make math.acos raise ValueError.
            cosinus = max(-1., min(1., cosinus))
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def faces(self):
        """
        Return, for each constraint, the list of vertices on which the
        constraint expression evaluates to zero.
        """
        vertices = self.vertices()
        faces = []
        for constraint in self.constraints:
            # vertex.coordinates() is called anew for each test on purpose:
            # its result may be a one-shot iterator.
            face = [vertex for vertex in vertices
                if constraint.subs(vertex.coordinates()) == 0]
            faces.append(face)
        return faces

    def plot(self):
        """
        Display a 2-dimensional polyhedron with matplotlib.

        Returns the list of plotted points for a 2-dimensional polyhedron
        and 0 for a 3-dimensional one, which is not drawn. Raises TypeError
        for any other dimension.
        """
        import matplotlib.pyplot as plt
        from matplotlib.path import Path
        import matplotlib.patches as patches

        dimension = len(self.symbols)
        if dimension == 3:
            # 3-dimensional plotting is not implemented.
            return 0
        if dimension != 2:
            # Dimensions 0 and 1 used to fall through to an unbound local.
            raise TypeError('cannot plot a polyhedron of dimension {}'.format(
                dimension))

        points = []
        for vert in self.vertices():
            # NOTE(review): vertices are accessed like mappings here whereas
            # faces() uses vertex.coordinates() -- confirm which interface
            # vertices() actually provides.
            points.append(tuple(vert.get(sym)
                for sym in sorted(vert, key=Symbol.sortkey)))
        # CLOSEPOLY needs a point of its own; matplotlib ignores its value.
        points.append((0.0, 0.0))
        # One code per point: MOVETO, then LINETOs, then CLOSEPOLY.
        codes = [Path.MOVETO]
        codes.extend([Path.LINETO] * (len(points) - 2))
        codes.append(Path.CLOSEPOLY)
        path = Path(points, codes)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        patch = patches.PathPatch(path, facecolor='blue', lw=2)
        ax.add_patch(patch)
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        plt.show()
        return points
322
323
def _polymorphic(func):
    """
    Decorate a binary function so that it accepts rational numbers as well
    as linear expressions.

    Each operand that is not an Expression must be a numbers.Rational, which
    is converted with Rational(); anything else raises TypeError. The
    decorated function is then called with the two converted operands.
    """
    def coerce(value, name):
        # Expressions pass through untouched; rationals are wrapped.
        if isinstance(value, Expression):
            return value
        if isinstance(value, numbers.Rational):
            return Rational(value)
        raise TypeError('{} must be a rational number '
            'or a linear expression'.format(name))

    @functools.wraps(func)
    def wrapper(left, right):
        # The left operand is validated first, then the right one.
        return func(coerce(left, 'left'), coerce(right, 'right'))
    return wrapper
341
@_polymorphic
def Lt(left, right):
    """
    Return the polyhedron for left < right, encoded as the single
    inequality right - left - 1 >= 0.
    """
    slack = right - left - 1
    return Polyhedron([], [slack])
345
@_polymorphic
def Le(left, right):
    """
    Return the polyhedron for left <= right, encoded as the single
    inequality right - left >= 0.
    """
    slack = right - left
    return Polyhedron([], [slack])
349
@_polymorphic
def Eq(left, right):
    """
    Return the polyhedron for left == right, encoded as the single
    equality left - right == 0.
    """
    difference = left - right
    return Polyhedron([difference], [])
353
@_polymorphic
def Ne(left, right):
    """
    Return the domain for left != right, obtained by applying the ~
    operator to Eq(left, right).
    """
    equal = Eq(left, right)
    return ~equal
357
@_polymorphic
def Gt(left, right):
    """
    Return the polyhedron for left > right, encoded as the single
    inequality left - right - 1 >= 0.
    """
    slack = left - right - 1
    return Polyhedron([], [slack])
361
@_polymorphic
def Ge(left, right):
    """
    Return the polyhedron for left >= right, encoded as the single
    inequality left - right >= 0.
    """
    slack = left - right
    return Polyhedron([], [slack])
365
366
# Predefined polyhedron built from the unsatisfiable equality 1 == 0.
Empty = Eq(1, 0)

# Predefined polyhedron built from an empty list of constraints.
Universe = Polyhedron(equalities=[])