cd118e8c8dd57b7c93a97dfaa376f5c2f99c138f
[linpy.git] / pypol / domains.py
1 import ast
2 import functools
3 import re
4
5 from fractions import Fraction
6
7 from . import islhelper
8 from .islhelper import mainctx, libisl, isl_set_basic_sets
9 from .geometry import GeometricObject
10 from .coordinates import Point
11 from .linexprs import Expression, Symbol
12
13
# Names exported by ``from <module> import *``.
__all__ = ['Domain', 'And', 'Or', 'Not']
18
19
@functools.total_ordering
class Domain(GeometricObject):
    """
    A union of polyhedra sharing an ordered tuple of symbols.

    Most operations build an isl set from the polyhedra (see _toislset),
    call into libisl, and convert the result back (see _fromislset).
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Build a domain from a string, a GeometricObject or polyhedra.

        With a single argument, a string is parsed by fromstring() and a
        GeometricObject is converted through its aspolyhedron() method;
        any other type raises TypeError. With no argument, Empty is
        returned. Otherwise every argument must be a Polyhedron instance
        (TypeError if not) and the result is their union.
        """
        from .polyhedra import Polyhedron
        if not polyhedra:
            # the union of no polyhedra is empty, as for Or()
            from .polyhedra import Empty
            return Empty
        if len(polyhedra) == 1:
            argument = polyhedra[0]
            if isinstance(argument, str):
                return cls.fromstring(argument)
            elif isinstance(argument, GeometricObject):
                return argument.aspolyhedron()
            else:
                raise TypeError('argument must be a string '
                    'or a GeometricObject instance')
        for polyhedron in polyhedra:
            if not isinstance(polyhedron, Polyhedron):
                raise TypeError('arguments must be Polyhedron instances')
        symbols = cls._xsymbols(polyhedra)
        islset = cls._toislset(polyhedra, symbols)
        return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """
        The tuple of polyhedra whose union forms this domain.
        """
        return self._polyhedra

    @property
    def symbols(self):
        """
        The ordered tuple of symbols of this domain.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The number of symbols of this domain.
        """
        return self._dimension

    def _islpredicate(self, predicate):
        """
        Apply a unary libisl predicate to this domain, return a bool.

        The temporary isl set is freed even if the call raises.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            return bool(predicate(islset))
        finally:
            libisl.isl_set_free(islset)

    def _islcompare(self, other, predicate):
        """
        Apply a binary libisl predicate to this domain and other.

        Both domains are expressed over the union of their symbols; the
        temporary isl sets are freed even if the call raises.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        try:
            return bool(predicate(islset1, islset2))
        finally:
            libisl.isl_set_free(islset1)
            libisl.isl_set_free(islset2)

    def disjoint(self):
        """
        Return an equivalent domain computed by isl_set_make_disjoint().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint() takes the set as its only argument
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """
        Return True if the domain is empty.
        """
        return self._islpredicate(libisl.isl_set_is_empty)

    def __bool__(self):
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if isl_set_plain_is_universe() holds for the domain.
        """
        # NOTE(review): isl "plain" tests are presumably fast but
        # incomplete checks -- confirm against the isl manual
        return self._islpredicate(libisl.isl_set_plain_is_universe)

    def isbounded(self):
        """
        Return True if the domain is bounded.
        """
        return self._islpredicate(libisl.isl_set_is_bounded)

    def __eq__(self, other):
        """
        Return True if both domains describe the same set.

        Operands that are not Domain instances yield NotImplemented.
        """
        if not isinstance(other, Domain):
            return NotImplemented
        return self._islcompare(other, libisl.isl_set_is_equal)

    def isdisjoint(self, other):
        """
        Return True if the two domains have no point in common.
        """
        return self._islcompare(other, libisl.isl_set_is_disjoint)

    def issubset(self, other):
        """
        Return True if this domain is included in other.
        """
        return self._islcompare(other, libisl.isl_set_is_subset)

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """
        Return True if this domain is strictly included in other.
        """
        return self._islcompare(other, libisl.isl_set_is_strict_subset)

    def complement(self):
        """
        Return the complement of the domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        return self.complement()

    def simplify(self):
        """
        Return the domain with redundant constraints removed by isl.
        """
        # does not change anything in any of the examples,
        # isl seems to do this naturally
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """
        Return the polyhedral hull of the domain as a Polyhedron.
        """
        # several types of hull are available;
        # polyhedral seems to be the most appropriate, to be checked
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def asdomain(self):
        """
        Return the domain itself.
        """
        return self

    def project(self, dims):
        """
        Return the domain with the symbols listed in dims projected out.

        dims must support membership tests; symbols of dims that are not
        in the domain are ignored.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # walk the symbols backwards and project out each maximal run of
        # consecutive symbols to remove, so that the indices of the
        # symbols still to be visited are never shifted
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                islset = libisl.isl_set_project_out(islset,
                    libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            # the pending run starts at the very first symbol
            islset = libisl.isl_set_project_out(islset,
                libisl.isl_dim_set, 0, n)
        remaining = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, remaining)

    def sample(self):
        """
        Return one point of the domain, as a dictionary mapping each
        symbol to its integer coordinate.

        Raise ValueError if isl returns a void point (empty domain).
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            raise ValueError('domain must be non-empty')
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            coordinate = islhelper.isl_val_to_int(coordinate)
            point[symbol] = coordinate
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """
        Return the intersection of this domain with all the others.

        Without argument, the domain itself is returned.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        return self.intersection(other)

    def union(self, *others):
        """
        Return the union of this domain with all the others.

        Without argument, the domain itself is returned.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        return self.union(other)

    def __add__(self, other):
        return self.union(other)

    def difference(self, other):
        """
        Return the points of this domain that are not in other.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        return self.difference(other)

    def lexmin(self):
        """
        Return the lexicographic minimum of the domain, as a domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """
        Return the lexicographic maximum of the domain, as a domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def num_parameters(self):
        """
        Return the number of set dimensions of the underlying isl set.
        """
        # could be useful with large, complicated polyhedra
        # NOTE(review): despite the method name, the dimension type
        # queried is isl_dim_set (as before), not isl_dim_param
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            return libisl.isl_set_dim(islset, libisl.isl_dim_set)
        finally:
            # isl_set_dim() does not take ownership of the set
            libisl.isl_set_free(islset)

    def involves_dims(self, dims):
        """
        Return True if the domain involves at least one symbol of dims.

        dims must support membership tests; symbols of dims that are not
        in the domain are ignored, and an empty dims yields False.
        """
        # could be useful with large, complicated polyhedra
        indices = [index for index, symbol in enumerate(self.symbols)
                   if symbol in dims]
        if not indices:
            return False
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            # the requested symbols need not be contiguous, hence each
            # dimension is tested on its own, at its actual isl index
            return any(
                bool(libisl.isl_set_involves_dims(islset,
                    libisl.isl_dim_set, index, 1))
                for index in indices)
        finally:
            libisl.isl_set_free(islset)

    _RE_COORDINATE = re.compile(r'\((?P<num>\-?\d+)\)(/(?P<den>\d+))?')

    def vertices(self):
        """
        Return the list of vertices, as Point instances with Fraction
        coordinates.

        This relies on the equalities, inequalities and _toislbasicset
        members, which are not defined by this class.
        """
        # NOTE(review): islbset and the isl vertices objects are never
        # freed here -- ownership of the helper results to be confirmed
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        vertices = libisl.isl_basic_set_compute_vertices(islbset)
        vertices = islhelper.isl_vertices_vertices(vertices)
        points = []
        for vertex in vertices:
            expr = libisl.isl_vertex_get_expr(vertex)
            coordinates = []
            # NOTE(review): this compares version strings textually,
            # e.g. '0.9' is not lower than '0.13' -- to be confirmed
            if islhelper.isl_version < '0.13':
                constraints = islhelper.isl_basic_set_constraints(expr)
                for constraint in constraints:
                    constant = libisl.isl_constraint_get_constant_val(
                        constraint)
                    constant = islhelper.isl_val_to_int(constant)
                    for index, symbol in enumerate(self.symbols):
                        coefficient = \
                            libisl.isl_constraint_get_coefficient_val(
                                constraint, libisl.isl_dim_set, index)
                        coefficient = islhelper.isl_val_to_int(coefficient)
                        if coefficient != 0:
                            coordinate = -Fraction(constant, coefficient)
                            coordinates.append((symbol, coordinate))
            else:
                # horrible hack, find a cleaner solution: coordinates
                # are parsed from the textual form of the expression
                string = islhelper.isl_multi_aff_to_str(expr)
                matches = self._RE_COORDINATE.finditer(string)
                for symbol, match in zip(self.symbols, matches):
                    numerator = int(match.group('num'))
                    denominator = match.group('den')
                    denominator = 1 if denominator is None \
                        else int(denominator)
                    coordinate = Fraction(numerator, denominator)
                    coordinates.append((symbol, coordinate))
            points.append(Point(coordinates))
        return points

    def points(self):
        """
        Return the list of integer points of the domain, as Point
        instances.

        Raise ValueError if the domain is not bounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        # NOTE(review): ownership of islset and of the returned isl
        # points depends on islhelper.isl_set_points -- to be confirmed
        islpoints = islhelper.isl_set_points(islset)
        points = []
        for islpoint in islpoints:
            coordinates = {}
            for index, symbol in enumerate(self.symbols):
                coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                    libisl.isl_dim_set, index)
                coordinate = islhelper.isl_val_to_int(coordinate)
                coordinates[symbol] = coordinate
            points.append(Point(coordinates))
        return points

    def __contains__(self, point):
        """
        Return True if point belongs to at least one of the polyhedra.
        """
        return any(point in polyhedron for polyhedron in self.polyhedra)

    def subs(self, symbol, expression=None):
        """
        Return the domain obtained by applying subs(symbol, expression)
        to each polyhedron.
        """
        polyhedra = [polyhedron.subs(symbol, expression)
            for polyhedron in self.polyhedra]
        return Domain(*polyhedra)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set into Empty, a Polyhedron or a Domain,
        depending on the number of basic sets it contains.

        The isl set is consumed.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = [Polyhedron._fromislbasicset(islbset, symbols)
            for islbset in islbsets]
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        else:
            # bypass Domain.__new__, which would rebuild the isl set
            self = object.__new__(Domain)
            self._polyhedra = tuple(polyhedra)
            self._symbols = cls._xsymbols(polyhedra)
            self._dimension = len(self._symbols)
            return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set, union of the basic sets built from the
        polyhedra over the given symbols.

        polyhedra must contain at least one element.
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into a domain.

        Supported nodes are single-statement modules, expression
        statements, '~', '&', '|' and chained comparisons using
        <, <=, ==, >= and >. Anything else raises SyntaxError.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            domain = cls._fromast(node.operand)
            # the operator is held by node.op, and its class is Invert
            if isinstance(node.op, ast.Invert):
                return Not(domain)
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    # unsupported comparison operator: invalid syntax
                    break
                # in a chain a < b < c, b is the left side of the next
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string into a domain.

        The string is rewritten into a Python expression using '&', '|'
        and '~', parsed with ast.parse() and converted by _fromast().
        Raise SyntaxError if the string is not valid.
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # add parentheses to force precedence; operands sit at even
        # positions of the split. Blank operands (e.g. what precedes a
        # unary '~') are left alone, '()~(...)' being invalid syntax
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0 and token.strip():
                tokens[i] = '({})'.format(token)
        string = ''.join(tokens)
        # parsed as a module, which _fromast() unwraps
        tree = ast.parse(string)
        return cls._fromast(tree)

    def __repr__(self):
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression into a domain.

        Boolean and relational sympy nodes are mapped to their
        counterparts in this package, other sympy expressions are
        converted by Expression.fromsympy(). Raise ValueError otherwise.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """
        Return the sympy disjunction of the polyhedra of the domain.
        """
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)
463
464
def And(*domains):
    """
    Return the intersection of the given domains.

    Without argument, Universe is returned.
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *rest = domains
    return first.intersection(*rest)
471
def Or(*domains):
    """
    Return the union of the given domains.

    Without argument, Empty is returned.
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *rest = domains
    return first.union(*rest)
478
def Not(domain):
    """
    Return the complement of domain, computed with the ~ operator.
    """
    complement = ~domain
    return complement