20493fa4b7fd9a657044acadf9dfc5b084e7aadc
[linpy.git] / pypol / domains.py
1 import ast
2 import functools
3 import re
4
5 from . import islhelper
6
7 from .islhelper import mainctx, libisl, isl_set_basic_sets
8 from .linexprs import Expression
9
10
# Public API re-exported by ``from <package>.domains import *``.
__all__ = ['Domain', 'And', 'Or', 'Not']
15
16
class Domain:
    """
    A finite union of polyhedra sharing an ordered tuple of symbols.

    ``Domain(...)`` accepts a string (parsed by ``fromstring``), a single
    Polyhedron (returned unchanged) or several Polyhedron instances (their
    union).  Set operations are delegated to isl: the polyhedra are converted
    to an isl set, the isl function is applied, and the result is converted
    back with ``_fromislset``.  That conversion yields ``Empty`` for an empty
    result, a Polyhedron for a single piece, and a Domain otherwise.

    Comparisons follow set semantics.  ``<=``/``<`` test (strict) inclusion
    and ``>=``/``>`` test the reverse.  Inclusion is only a partial order, so
    all four operators are defined explicitly rather than derived with
    ``functools.total_ordering``.  That decorator would compute ``a >= b`` as
    ``not a < b``, which is wrong for incomparable domains.
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Build a domain.

        The accepted forms are:

        - ``Domain(string)`` parses *string* (see ``fromstring``).
        - ``Domain(polyhedron)`` returns the polyhedron itself.
        - ``Domain(p1, p2, ...)`` returns the union of the given polyhedra.
        - ``Domain()`` returns the empty domain, like ``Or()``.

        Raises TypeError if an argument has an unsupported type.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 0:
            # _toislset needs at least one polyhedron; the empty union is
            # Empty, as for Or().
            from .polyhedra import Empty
            return Empty
        if len(polyhedra) == 1:
            polyhedron = polyhedra[0]
            if isinstance(polyhedron, str):
                return cls.fromstring(polyhedron)
            elif isinstance(polyhedron, Polyhedron):
                return polyhedron
            raise TypeError('argument must be a string '
                            'or a Polyhedron instance')
        for polyhedron in polyhedra:
            if not isinstance(polyhedron, Polyhedron):
                raise TypeError('arguments must be Polyhedron instances')
        symbols = cls._xsymbols(polyhedra)
        islset = cls._toislset(polyhedra, symbols)
        return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.

        Each item must expose a ``symbols`` iterable; the union of all of
        them is returned sorted, without duplicates.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols))

    @property
    def polyhedra(self):
        """The tuple of polyhedra whose union is this domain."""
        return self._polyhedra

    @property
    def symbols(self):
        """The sorted tuple of symbols appearing in the polyhedra."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols."""
        return self._dimension

    def disjoint(self):
        """Return an equivalent domain computed by isl_set_make_disjoint."""
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint takes a single isl_set argument; it has no
        # context parameter.
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """Return True if the domain contains no point."""
        islset = self._toislset(self.polyhedra, self.symbols)
        empty = bool(libisl.isl_set_is_empty(islset))
        libisl.isl_set_free(islset)
        return empty

    def __bool__(self):
        """A domain is truthy when it is not empty."""
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if the domain is known to be the whole space.

        This uses isl_set_plain_is_universe, a quick ("plain") check.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        universe = bool(libisl.isl_set_plain_is_universe(islset))
        libisl.isl_set_free(islset)
        return universe

    def isbounded(self):
        """Return True if the domain is bounded."""
        islset = self._toislset(self.polyhedra, self.symbols)
        bounded = bool(libisl.isl_set_is_bounded(islset))
        libisl.isl_set_free(islset)
        return bounded

    def _islpredicate(self, other, predicate):
        """
        Evaluate the binary isl predicate *predicate* on self and *other*.

        Both domains are converted to isl sets over the union of their
        symbols.  The predicates used here do not consume their arguments,
        so both sets are freed afterwards, even if the call raises.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        try:
            return bool(predicate(islset1, islset2))
        finally:
            libisl.isl_set_free(islset1)
            libisl.isl_set_free(islset2)

    def __eq__(self, other):
        """Return True if both domains contain exactly the same points."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self._islpredicate(other, libisl.isl_set_is_equal)

    def isdisjoint(self, other):
        """Return True if self and *other* have no point in common."""
        return self._islpredicate(other, libisl.isl_set_is_disjoint)

    def issubset(self, other):
        """Return True if every point of self is in *other*."""
        return self._islpredicate(other, libisl.isl_set_is_subset)

    def __le__(self, other):
        """``self <= other``: self is a subset of *other*."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other):
        """``self < other``: self is a strict subset of *other*."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self._islpredicate(other, libisl.isl_set_is_strict_subset)

    def __ge__(self, other):
        """``self >= other``: *other* is a subset of self."""
        if not isinstance(other, Domain):
            return NotImplemented
        return other.issubset(self)

    def __gt__(self, other):
        """``self > other``: *other* is a strict subset of self."""
        if not isinstance(other, Domain):
            return NotImplemented
        return other._islpredicate(self, libisl.isl_set_is_strict_subset)

    def complement(self):
        """Return the complement of the domain over its own symbols."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        """``~self`` is the complement."""
        return self.complement()

    def simplify(self):
        # see isl_set_coalesce, isl_set_detect_equalities,
        # isl_set_remove_redundancies
        # which ones? in which order?
        raise NotImplementedError

    def polyhedral_hull(self):
        """Return the polyhedral hull of the domain as a Polyhedron."""
        # several types of hull are available
        # polyhedral seems to be the more appropriate, to be checked
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def project(self, symbols):
        # not sure what isl_set_project_out actually does…
        # use isl_set_drop_constraints_involving_dims instead?
        raise NotImplementedError

    def sample(self):
        """
        Return, as a Polyhedron, the basic set computed by isl_set_sample.

        This is presumably a single point of the domain; confirm against
        the isl documentation.
        """
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_sample(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def intersection(self, *others):
        """Return the intersection of self with all *others*."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        """``self & other`` is the intersection."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self.intersection(other)

    def union(self, *others):
        """Return the union of self with all *others*."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        """``self | other`` is the union."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self.union(other)

    def __add__(self, other):
        """``self + other`` is the union."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self.union(other)

    def difference(self, other):
        """Return the points of self that are not in *other*."""
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        """``self - other`` is the difference."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self.difference(other)

    def lexmin(self):
        """Return the lexicographic minimum computed by isl_set_lexmin."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """Return the lexicographic maximum computed by isl_set_lexmax."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set to a domain, consuming *islset*.

        The return value depends on the number of basic sets:

        - none: ``Empty``;
        - one: that Polyhedron;
        - several: a Domain over them, whose symbols are recomputed from the
          resulting polyhedra.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = [Polyhedron._fromislbasicset(islbset, symbols)
                     for islbset in islbsets]
        if not polyhedra:
            from .polyhedra import Empty
            return Empty
        if len(polyhedra) == 1:
            return polyhedra[0]
        # Allocate directly, always as a plain Domain even when called
        # through a subclass.  Domain(*polyhedra) would go through __new__,
        # which calls _fromislset again.
        self = object.__new__(Domain)
        self._polyhedra = tuple(polyhedra)
        self._symbols = cls._xsymbols(polyhedra)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set equal to the union of *polyhedra* over *symbols*.

        *polyhedra* must be non-empty.  The caller must free the returned set
        or pass it to an isl function that consumes it.  This must be a
        classmethod because __new__ calls it on the class.
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
                                            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                                                polyhedron.inequalities,
                                                symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST, as produced by ``fromstring``, to a domain.

        The mapping is:

        - ``~`` becomes Not;
        - ``&`` becomes And;
        - ``|`` becomes Or;
        - a (possibly chained) comparison becomes a Polyhedron.

        Raises SyntaxError for any other construct.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            domain = cls._fromast(node.operand)
            # The operator is node.op (node.operand is its argument), and
            # the AST class for '~' is ast.Invert.
            if isinstance(node.op, ast.Invert):
                return Not(domain)
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            # Equalities are stored as 'expr == 0' and inequalities as
            # 'expr >= 0'.  Strict comparisons subtract 1, which assumes
            # integer-valued expressions.
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        r"""
        Parse a domain from *string*.

        The syntax is a boolean combination of (possibly chained) linear
        comparisons, optionally enclosed in braces:

        - a lone ``=`` means equality;
        - conjunction: ``and``, ``,``, ``&&``, ``/\``, ``∧`` or ``∩``;
        - disjunction: ``or``, ``;``, ``||``, ``\/``, ``∨`` or ``∪``;
        - negation: ``not``, ``!`` or ``¬``;
        - implicit multiplication such as ``2x`` is accepted.

        Raises SyntaxError if the string is not valid domain syntax.
        """
        # remove brackets
        string = re.sub(r'^\{\s*|\s*\}$', '', string)
        # Replace a lone '=' by '=='.  Lookarounds (instead of consuming the
        # neighbouring characters) convert chains like 'x=y=z' completely
        # and leave '<=', '>=' and '==' intact.
        string = re.sub(r'(?<![<=>])=(?![<=>])', '==', string)
        # replace 'and', 'or', 'not'
        string = re.sub(r'\band\b|,|&&|/\\|∧|∩', r' & ', string)
        string = re.sub(r'\bor\b|;|\|\||\\/|∨|∪', r' | ', string)
        string = re.sub(r'\bnot\b|!|¬', r' ~', string)
        # Even indices are operands, odd indices the operators themselves.
        tokens = re.split(r'(&|\||~)', string)
        for i, token in enumerate(tokens):
            if i % 2 == 0:
                if not token.strip():
                    # The operand before a prefix '~' is blank.  Wrapping it
                    # as '( )' would give invalid Python such as '( )~(x>0)'.
                    tokens[i] = ''
                    continue
                # add implicit multiplication operators, e.g. '5x' -> '5*x'
                token = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', token)
                tokens[i] = '({})'.format(token)
        string = ''.join(tokens)
        tree = ast.parse(string)
        return cls._fromast(tree)

    def __repr__(self):
        # A Domain instance is only built by _fromislset with at least two
        # polyhedra; smaller results are returned as Polyhedron/Empty.
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        raise NotImplementedError

    def tosympy(self):
        raise NotImplementedError
322
323
def And(*domains):
    """
    Return the intersection of all *domains*.

    With no argument, return the universe, the neutral element of
    intersection.
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *rest = domains
    return first.intersection(*rest)
330
def Or(*domains):
    """
    Return the union of all *domains*.

    With no argument, return the empty domain, the neutral element of
    union.
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *rest = domains
    return first.union(*rest)
337
def Not(domain):
    """Return the complement of *domain*, i.e. ``~domain``."""
    complement = ~domain
    return complement