Overloading for Domains.project_out(), to be improved
[linpy.git] / pypol / domains.py
index c844e55..6b47fe8 100644 (file)
@@ -5,7 +5,7 @@ import re
 from . import islhelper
 
 from .islhelper import mainctx, libisl, isl_set_basic_sets
-from .linexprs import Expression
+from .linexprs import Expression, Symbol
 
 
 __all__ = [
@@ -154,6 +154,15 @@ class Domain:
 
     def project_out(self, symbols):
         # use to remove certain variables
+        if isinstance(symbols, str):
+            symbols = symbols.replace(',', ' ').split()
+        else:
+            symbols = list(symbols)
+            for i, symbol in enumerate(symbols):
+                if isinstance(symbol, Symbol):
+                    symbols[i] = symbol.name
+                elif not isinstance(symbol, str):
+                    raise TypeError('symbols must be strings or Symbol instances')
         islset = self._toislset(self.polyhedra, self.symbols)
         # the trick is to walk symbols in reverse order, to avoid index updates
         for index, symbol in reversed(list(enumerate(self.symbols))):