Use displaystyle in _repr_latex_
[linpy.git] / pypol / domains.py
1 import ast
2 import functools
3 import re
4
5 from fractions import Fraction
6
7 from . import islhelper
8 from .islhelper import mainctx, libisl
9 from .geometry import GeometricObject, Point
10 from .linexprs import Expression, Symbol
11
12
# Names exported by ``from ... import *``: the Domain class and the
# boolean combinators built on top of it.
__all__ = ['Domain', 'And', 'Or', 'Not']
17
18
@functools.total_ordering
class Domain(GeometricObject):
    """
    A finite union of polyhedra over a common tuple of symbols.

    Most operations convert the polyhedra to an isl set with ``_toislset``,
    call an ``isl_set_*`` function through ``libisl`` and convert the result
    back with ``_fromislset``.  ``_fromislset`` returns ``Empty`` or a single
    polyhedron when the result has fewer than two basic sets, so an object
    whose exact type is Domain always holds at least two polyhedra.

    Domains are ordered by inclusion, which is a partial order: ``a <= b``
    means that ``a`` is a subset of ``b``.
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Create a domain.

        - ``Domain(string)`` parses *string* with ``fromstring``.
        - ``Domain(obj)`` with one GeometricObject returns
          ``obj.aspolyhedron()``.
        - ``Domain(p1, p2, ...)`` returns the union of Polyhedron instances.
        - ``Domain()`` returns the empty domain, as ``Or()`` does.

        Raises TypeError for arguments of any other type.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 1:
            argument = polyhedra[0]
            if isinstance(argument, str):
                return cls.fromstring(argument)
            elif isinstance(argument, GeometricObject):
                return argument.aspolyhedron()
            else:
                raise TypeError('argument must be a string '
                    'or a GeometricObject instance')
        if not polyhedra:
            # _toislset() needs at least one polyhedron; the empty union is
            # the empty set.
            from .polyhedra import Empty
            return Empty
        for polyhedron in polyhedra:
            if not isinstance(polyhedron, Polyhedron):
                raise TypeError('arguments must be Polyhedron instances')
        symbols = cls._xsymbols(polyhedra)
        islset = cls._toislset(polyhedra, symbols)
        return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the symbols of all items in iterator as a tuple without
        duplicates, sorted with Symbol.sortkey.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """The tuple of polyhedra whose union is this domain."""
        return self._polyhedra

    @property
    def symbols(self):
        """The tuple of symbols the domain is expressed over."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols of the domain."""
        return self._dimension

    def disjoint(self):
        """
        Return an equivalent domain whose polyhedra are pairwise disjoint.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint() takes a single (__isl_take) isl_set; a
        # context argument would be read as the set.
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """
        Return True if the domain contains no point.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        empty = bool(libisl.isl_set_is_empty(islset))
        libisl.isl_set_free(islset)
        return empty

    def __bool__(self):
        """A domain is truthy when it is not empty."""
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if the domain is, as far as isl can tell without
        simplification (isl_set_plain_is_universe), the universe.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        universe = bool(libisl.isl_set_plain_is_universe(islset))
        libisl.isl_set_free(islset)
        return universe

    def isbounded(self):
        """
        Return True if the domain is bounded.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        bounded = bool(libisl.isl_set_is_bounded(islset))
        libisl.isl_set_free(islset)
        return bounded

    def _islpredicate(self, other, predicate):
        """
        Evaluate the binary isl set predicate ``predicate(self, other)``.

        Both domains are converted to isl sets over the union of their
        symbols; the two sets are freed before the boolean result is
        returned.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        result = bool(predicate(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return result

    def __eq__(self, other):
        """
        Return True if both domains contain exactly the same points.

        Comparing with an object that is not a Domain returns NotImplemented.
        """
        if not isinstance(other, Domain):
            return NotImplemented
        return self._islpredicate(other, libisl.isl_set_is_equal)

    def isdisjoint(self, other):
        """
        Return True if the two domains have an empty intersection.
        """
        return self._islpredicate(other, libisl.isl_set_is_disjoint)

    def issubset(self, other):
        """
        Return True if every point of this domain belongs to other.
        """
        return self._islpredicate(other, libisl.isl_set_is_subset)

    def __le__(self, other):
        """
        Return True if this domain is a subset of other.
        """
        return self.issubset(other)

    def __lt__(self, other):
        """
        Return True if this domain is a strict subset of other.
        """
        return self._islpredicate(other, libisl.isl_set_is_strict_subset)

    def __ge__(self, other):
        """
        Return True if this domain is a superset of other.

        Defined explicitly: the version functools.total_ordering derives
        from __lt__ (``not self < other``) is wrong for inclusion, which is
        only a partial order (two disjoint domains would compare as >=).
        """
        if not isinstance(other, Domain):
            return NotImplemented
        return other.issubset(self)

    def __gt__(self, other):
        """
        Return True if this domain is a strict superset of other.

        Defined explicitly for the same reason as __ge__.
        """
        if not isinstance(other, Domain):
            return NotImplemented
        return other._islpredicate(self, libisl.isl_set_is_strict_subset)

    def complement(self):
        """
        Return the complement of this domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        """
        Return the complement of this domain.
        """
        return self.complement()

    def simplify(self):
        """
        Return an equivalent domain without redundant constraints.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """
        Return the polyhedral hull of the domain as a Polyhedron.
        """
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def asdomain(self):
        """Return the domain itself."""
        return self

    def project(self, dims):
        """
        Return a new domain with the symbols in dims projected out.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # Walk the symbols from last to first and project out each run of
        # consecutive symbols to remove in one call; going backwards keeps
        # the indices of the symbols still to visit valid.
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                islset = libisl.isl_set_project_out(islset,
                    libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            islset = libisl.isl_set_project_out(islset, libisl.isl_dim_set,
                0, n)
        dims = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, dims)

    def sample(self):
        """
        Return one integer point of the domain as a dict mapping each symbol
        to its coordinate.

        Raises ValueError if the domain is empty.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            raise ValueError('domain must be non-empty')
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            coordinate = islhelper.isl_val_to_int(coordinate)
            point[symbol] = coordinate
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """
        Return the intersection of this domain with all others.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        """
        Return the intersection of two domains.
        """
        return self.intersection(other)

    def union(self, *others):
        """
        Return the union of this domain with all others.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        """
        Return the union of two domains.
        """
        return self.union(other)

    def __add__(self, other):
        """
        Return the union of two domains.
        """
        return self.union(other)

    def difference(self, other):
        """
        Return the points of this domain that are not in other.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        """
        Return the difference of two domains.
        """
        return self.difference(other)

    def lexmin(self):
        """
        Return a new domain containing the lexicographic minimum of the
        domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """
        Return a new domain containing the lexicographic maximum of the
        domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def num_parameters(self):
        """
        Return the number of set dimensions of the domain.
        """
        # Built from all polyhedra, so this also works for a Domain, which
        # has no equalities/inequalities attributes of its own.
        islset = self._toislset(self.polyhedra, self.symbols)
        num = libisl.isl_set_dim(islset, libisl.isl_dim_set)
        libisl.isl_set_free(islset)
        return num

    def involves_dims(self, dims):
        """
        Return True if the domain's constraints depend on any symbol in dims.

        Symbols of dims that the domain is not expressed over are ignored;
        an empty dims yields False.
        """
        dims = list(dims)
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            # Dimension positions in the isl set follow self.symbols, so
            # each requested symbol is checked at its own index.
            for index, symbol in enumerate(self.symbols):
                if symbol in dims and bool(libisl.isl_set_involves_dims(
                        islset, libisl.isl_dim_set, index, 1)):
                    return True
            return False
        finally:
            libisl.isl_set_free(islset)

    # Matches coordinates such as '(3)' or '(-1)/2' in isl's textual form of
    # a vertex (used by vertices() on isl >= 0.13).
    _RE_COORDINATE = re.compile(r'\((?P<num>\-?\d+)\)(/(?P<den>\d+))?')

    def vertices(self):
        """
        Return the vertices of the domain as a list of Point instances with
        Fraction coordinates.

        Raises ValueError if the domain is unbounded.
        """
        # NOTE(review): this reads self.equalities / self.inequalities, which
        # a Domain made of several polyhedra does not define, so it only
        # works on Polyhedron instances.  The basic set built here is never
        # freed in this method -- confirm whether
        # isl_basic_set_compute_vertices keeps or takes it.
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        vertices = libisl.isl_basic_set_compute_vertices(islbset)
        vertices = islhelper.isl_vertices_vertices(vertices)
        points = []
        for vertex in vertices:
            expr = libisl.isl_vertex_get_expr(vertex)
            coordinates = []
            # NOTE(review): plain string comparison of version numbers
            # ('0.9' > '0.13'); confirm the format of islhelper.isl_version.
            if islhelper.isl_version < '0.13':
                constraints = islhelper.isl_basic_set_constraints(expr)
                for constraint in constraints:
                    constant = libisl.isl_constraint_get_constant_val(
                        constraint)
                    constant = islhelper.isl_val_to_int(constant)
                    for index, symbol in enumerate(self.symbols):
                        coefficient = libisl.isl_constraint_get_coefficient_val(
                            constraint, libisl.isl_dim_set, index)
                        coefficient = islhelper.isl_val_to_int(coefficient)
                        if coefficient != 0:
                            coordinate = -Fraction(constant, coefficient)
                            coordinates.append((symbol, coordinate))
            else:
                # horrible hack, find a cleaner solution: parse isl's
                # string form of the vertex expression.
                string = islhelper.isl_multi_aff_to_str(expr)
                matches = self._RE_COORDINATE.finditer(string)
                for symbol, match in zip(self.symbols, matches):
                    numerator = int(match.group('num'))
                    denominator = match.group('den')
                    denominator = 1 if denominator is None else int(denominator)
                    coordinate = Fraction(numerator, denominator)
                    coordinates.append((symbol, coordinate))
            points.append(Point(coordinates))
        return points

    def points(self):
        """
        Return the integer points of the domain as a list of Point instances.

        Raises ValueError if the domain is unbounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoints = islhelper.isl_set_points(islset)
        points = []
        for islpoint in islpoints:
            coordinates = {}
            for index, symbol in enumerate(self.symbols):
                coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                    libisl.isl_dim_set, index)
                coordinate = islhelper.isl_val_to_int(coordinate)
                coordinates[symbol] = coordinate
            points.append(Point(coordinates))
        return points

    def __contains__(self, point):
        """Return True if point belongs to one of the polyhedra."""
        for polyhedron in self.polyhedra:
            if point in polyhedron:
                return True
        return False

    def subs(self, symbol, expression=None):
        """
        Return the domain obtained by applying polyhedron.subs(symbol,
        expression) to every polyhedron.
        """
        polyhedra = [polyhedron.subs(symbol, expression)
            for polyhedron in self.polyhedra]
        return Domain(*polyhedra)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set (consumed) into Empty, a single Polyhedron, or a
        Domain when it has two or more basic sets.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = islhelper.isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = []
        for islbset in islbsets:
            polyhedron = Polyhedron._fromislbasicset(islbset, symbols)
            polyhedra.append(polyhedron)
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        else:
            # Bypass Domain.__new__, which would convert back to isl.
            self = object.__new__(Domain)
            self._polyhedra = tuple(polyhedra)
            self._symbols = cls._xsymbols(polyhedra)
            self._dimension = len(self._symbols)
            return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set, owned by the caller, equal to the union of the
        (non-empty sequence of) polyhedra expressed over symbols.
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Build a domain from the Python AST produced by fromstring().

        '&' maps to And, '|' to Or, '~' to Not and a (possibly chained)
        comparison to a Polyhedron.  Raises SyntaxError for anything else.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            domain = cls._fromast(node.operand)
            # The operator is stored in node.op and its AST class is
            # ast.Invert.
            if isinstance(node.op, ast.Invert):
                return Not(domain)
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            # Inequalities are collected as 'expr >= 0' and equalities as
            # 'expr == 0'; a strict 'a < b' becomes 'b - a - 1 >= 0', which
            # is exact for integer points.
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    # Textual rewriting rules applied by fromstring().
    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a domain from a string such as '{x >= 0, y < x}'.

        The string is rewritten into a Python expression ('=' -> '==',
        and/or/not -> &, |, ~, '5x' -> '5*x', operands parenthesized) and
        parsed with the ast module, never evaluated.
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # add parentheses to force precedence; operands are at even indices
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0:
                # A blank operand appears before a unary '~' (e.g. for
                # 'not x > 0'); wrapping it would yield the invalid
                # '()~(...)', so it is dropped instead.
                tokens[i] = '({})'.format(token) if token.strip() else ''
        string = ''.join(tokens)
        # The second positional parameter of ast.parse is the filename, not
        # the mode; the default 'exec' mode yields the Module _fromast
        # expects.
        tree = ast.parse(string)
        return cls._fromast(tree)

    def __repr__(self):
        """Return 'Or(...)' listing the repr of every polyhedron."""
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    def _repr_latex_(self):
        """Return a LaTeX disjunction of the polyhedra's LaTeX forms."""
        strings = []
        for polyhedron in self.polyhedra:
            strings.append('({})'.format(polyhedron._repr_latex_().strip('$')))
        return '$${}$$'.format(' \\vee '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy boolean/relational expression into a domain, or a
        plain sympy expression into an Expression.

        Raises ValueError for anything else.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """
        Return the domain as a sympy.Or of the polyhedra's sympy forms.
        """
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)
547
548
def And(*domains):
    """
    Return the intersection of all given domains.

    The first domain's ``intersection`` method is called with the remaining
    ones; with no argument at all, the Universe domain is returned.
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *rest = domains
    return first.intersection(*rest)
558
def Or(*domains):
    """
    Return the union of all given domains.

    The first domain's ``union`` method is called with the remaining ones;
    with no argument at all, the Empty domain is returned.
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *rest = domains
    return first.union(*rest)
568
def Not(domain):
    """
    Return the complement of *domain*, computed with the ``~`` operator.
    """
    complement = ~domain
    return complement