Split linear.py and add domains
author Vivien Maisonneuve <v.maisonneuve@gmail.com>
Wed, 25 Jun 2014 11:56:34 +0000 (13:56 +0200)
committer Vivien Maisonneuve <v.maisonneuve@gmail.com>
Wed, 25 Jun 2014 12:12:33 +0000 (14:12 +0200)
17 files changed:
Makefile
pypol/.gitignore
pypol/__init__.py
pypol/_isl.c [deleted file]
pypol/_islhelper.c [new file with mode: 0644]
pypol/domains.py [new file with mode: 0644]
pypol/isl.py [deleted file]
pypol/islhelper.py [new file with mode: 0644]
pypol/linear.py [deleted file]
pypol/linexprs.py [new file with mode: 0644]
pypol/polyhedra.py [new file with mode: 0644]
pypol/tests/__init__.py [moved from tests/__init__.py with 100% similarity]
pypol/tests/test_domains.py [new file with mode: 0644]
pypol/tests/test_linexprs.py [moved from tests/test_linear.py with 80% similarity]
pypol/tests/test_polyhedra.py [new file with mode: 0644]
setup.py
tests/test_isl.py [deleted file]

index f2f5e9c..687d36b 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -25,6 +25,6 @@ test: build
 
 .PHONY: clean
+# Remove build outputs, the venv, the compiled helper extension and caches.
 clean:
-       $(RM) build dist MANIFEST venv
-       $(RM) $(NAME).egg-info $(NAME)/_isl.*.so $(NAME)/__pycache__
-       $(RM) tests/__pycache__
+       $(RM) build dist MANIFEST venv $(NAME).egg-info
+       $(RM) $(NAME)/_islhelper.*.so $(NAME)/__pycache__ $(NAME)/tests/__pycache__
index 7951d19..fd69600 100644 (file)
@@ -1 +1 @@
-/_isl.*.so
+/_islhelper.*.so
index 451d7f1..fc70bea 100644 (file)
@@ -2,13 +2,13 @@
 A polyhedral library based on ISL.
 """
 
-from .linear import Polyhedron, Constant, Symbol, symbols
-from .linear import eq, le, lt, ge, gt
-from .linear import Empty, Universe
+from .linexprs import Expression, Constant, Symbol, symbols
+from .polyhedra import Polyhedron, Eq, Ne, Le, Lt, Ge, Gt, Empty, Universe
+from .domains import Domain, And, Or, Not
 
 
 __all__ = [
-    'Polyhedron', 'Constant', 'Symbol', 'symbols',
-    'eq', 'le', 'lt', 'ge', 'gt',
-    'Empty', 'Universe'
+    'Expression', 'Constant', 'Symbol', 'symbols',
+    'Polyhedron', 'Eq', 'Ne', 'Le', 'Lt', 'Ge', 'Gt', 'Empty', 'Universe',
+    'Domain', 'And', 'Or', 'Not',
 ]
diff --git a/pypol/_isl.c b/pypol/_isl.c
deleted file mode 100644 (file)
index 1b086dc..0000000
+++ /dev/null
@@ -1,89 +0,0 @@
-#include <Python.h>
-
-#include <isl/constraint.h>
-#include <isl/set.h>
-
-struct pointer_list {
-    int cursor;
-    PyObject *pointers;
-};
-typedef struct pointer_list pointer_list;
-
-static int pointer_list_append_constraint(isl_constraint *c, void *user) {
-    pointer_list *list;
-    PyObject *value;
-
-    list = (pointer_list *) user;
-    value = PyLong_FromVoidPtr(c);
-    if (value == NULL) {
-        return -1;
-    }
-    PyList_SET_ITEM(list->pointers, list->cursor++, value);
-    return 0;
-}
-
-static PyObject * basic_set_constraints(PyObject *self, PyObject* args) {
-    long ptr;
-    isl_basic_set *bset;
-    int n;
-    PyObject *pointers;
-    pointer_list *list;
-
-    if (!PyArg_ParseTuple(args, "l", &ptr))
-        return NULL;
-    bset = (isl_basic_set*) ptr;
-    n = isl_basic_set_n_constraint(bset);
-    if (n == -1) {
-        PyErr_SetString(PyExc_RuntimeError,
-            "an error occurred in isl_basic_set_n_constraint");
-        return NULL;
-    }
-    pointers = PyList_New(n);
-    if (pointers == NULL) {
-        return NULL;
-    }
-    list = malloc(sizeof(pointer_list));
-    if (list == NULL) {
-        Py_DECREF(pointers);
-        return PyErr_NoMemory();
-    }
-    list->cursor = 0;
-    list->pointers = pointers;
-    n = isl_basic_set_foreach_constraint(bset, pointer_list_append_constraint,
-        list);
-    free(list);
-    if (n == -1) {
-        PyErr_SetString(PyExc_RuntimeError,
-            "an error occurred in isl_basic_set_foreach_constraint");
-        Py_DECREF(pointers);
-        return NULL;
-    }
-    return pointers;
-}
-
-static PyMethodDef _isl_methods[] = {
-    {"basic_set_constraints", basic_set_constraints, METH_VARARGS, NULL},
-    {NULL, NULL, 0, NULL}
-};
-
-static struct PyModuleDef _islmodule = {
-    PyModuleDef_HEAD_INIT,
-    "_isl",
-    NULL,
-    0,
-    _isl_methods
-};
-
-PyMODINIT_FUNC PyInit__isl(void) {
-    PyObject *m;
-    m = PyModule_Create(&_islmodule);
-    if (m == NULL) {
-        return NULL;
-    }
-
-    if (PyModule_AddObject(m, "dim_set", PyLong_FromLong(isl_dim_set)) == -1) {
-        return NULL;
-    }
-
-    return m;
-}
diff --git a/pypol/_islhelper.c b/pypol/_islhelper.c
new file mode 100644 (file)
index 0000000..f8c03e3
--- /dev/null
@@ -0,0 +1,205 @@
+#include <Python.h>
+
+#include <isl/constraint.h>
+#include <isl/set.h>
+
+
+/*
+ * Accumulator threaded through the isl "foreach" callbacks: a Python list
+ * pre-sized from the announced item count, and the next slot to fill.
+ */
+struct pointer_list {
+    Py_ssize_t cursor;
+    PyObject *pointers;
+};
+
+typedef struct pointer_list pointer_list;
+
+
+/*
+ * Parse the single positional argument of a helper as a raw isl pointer.
+ * Unlike the "l" format, PyLong_AsVoidPtr() also works where long is
+ * narrower than a pointer (LLP64 platforms such as 64-bit Windows).
+ * Returns 0 on success, -1 with a Python exception set on failure.
+ */
+static int parse_pointer(PyObject *args, void **ptr) {
+    PyObject *obj;
+
+    if (!PyArg_ParseTuple(args, "O", &obj)) {
+        return -1;
+    }
+    *ptr = PyLong_AsVoidPtr(obj);
+    if (*ptr == NULL && PyErr_Occurred()) {
+        return -1;
+    }
+    return 0;
+}
+
+/*
+ * Store ptr as a Python int in the next free slot of list.
+ * Returns 0 on success, -1 with a Python exception set on failure.
+ */
+static int pointer_list_append(pointer_list *list, void *ptr) {
+    PyObject *value;
+
+    /* Never write past the slots allocated from the announced count. */
+    if (list->cursor >= PyList_GET_SIZE(list->pointers)) {
+        PyErr_SetString(PyExc_RuntimeError,
+            "isl produced more items than announced");
+        return -1;
+    }
+    value = PyLong_FromVoidPtr(ptr);
+    if (value == NULL) {
+        return -1;
+    }
+    PyList_SET_ITEM(list->pointers, list->cursor++, value);
+    return 0;
+}
+
+
+/* isl_basic_set_foreach_constraint() callback; it receives ownership of c. */
+static int pointer_list_append_constraint(isl_constraint *c, void *user) {
+    if (pointer_list_append((pointer_list *) user, c) == -1) {
+        isl_constraint_free(c);
+        return -1;
+    }
+    return 0;
+}
+
+/*
+ * Python: isl_basic_set_constraints(ptr) -> list of isl_constraint pointers,
+ * one per constraint of the isl_basic_set at ptr.
+ */
+static PyObject * islhelper_basic_set_constraints(PyObject *self,
+        PyObject *args) {
+    void *ptr;
+    isl_basic_set *bset;
+    int n;
+    Py_ssize_t i;
+    PyObject *pointers;
+    pointer_list list;
+
+    if (parse_pointer(args, &ptr) == -1) {
+        return NULL;
+    }
+    bset = (isl_basic_set *) ptr;
+    n = isl_basic_set_n_constraint(bset);
+    if (n == -1) {
+        PyErr_SetString(PyExc_RuntimeError,
+            "an error occurred in isl_basic_set_n_constraint");
+        return NULL;
+    }
+    pointers = PyList_New(n);
+    if (pointers == NULL) {
+        return NULL;
+    }
+    /* The accumulator only lives for this call, so it stays on the stack. */
+    list.cursor = 0;
+    list.pointers = pointers;
+    if (isl_basic_set_foreach_constraint(bset, pointer_list_append_constraint,
+            &list) == -1 || list.cursor != n) {
+        /* Free the constraints already collected: nobody else owns them. */
+        for (i = 0; i < list.cursor; i++) {
+            isl_constraint_free(PyLong_AsVoidPtr(PyList_GET_ITEM(pointers, i)));
+        }
+        Py_DECREF(pointers);
+        if (!PyErr_Occurred()) {
+            PyErr_SetString(PyExc_RuntimeError,
+                "an error occurred in isl_basic_set_foreach_constraint");
+        }
+        return NULL;
+    }
+    return pointers;
+}
+
+
+/* isl_set_foreach_basic_set() callback; it receives ownership of bset. */
+static int pointer_list_append_basic_set(isl_basic_set *bset, void *user) {
+    if (pointer_list_append((pointer_list *) user, bset) == -1) {
+        isl_basic_set_free(bset);
+        return -1;
+    }
+    return 0;
+}
+
+/*
+ * Python: isl_set_basic_sets(ptr) -> list of isl_basic_set pointers, one per
+ * basic set of the isl_set at ptr.
+ */
+static PyObject * islhelper_set_basic_sets(PyObject *self, PyObject *args) {
+    void *ptr;
+    isl_set *set;
+    int n;
+    Py_ssize_t i;
+    PyObject *pointers;
+    pointer_list list;
+
+    if (parse_pointer(args, &ptr) == -1) {
+        return NULL;
+    }
+    set = (isl_set *) ptr;
+    n = isl_set_n_basic_set(set);
+    if (n == -1) {
+        PyErr_SetString(PyExc_RuntimeError,
+            "an error occurred in isl_set_n_basic_set");
+        return NULL;
+    }
+    pointers = PyList_New(n);
+    if (pointers == NULL) {
+        return NULL;
+    }
+    list.cursor = 0;
+    list.pointers = pointers;
+    if (isl_set_foreach_basic_set(set, pointer_list_append_basic_set,
+            &list) == -1 || list.cursor != n) {
+        /* Free the basic sets already collected: nobody else owns them. */
+        for (i = 0; i < list.cursor; i++) {
+            isl_basic_set_free(PyLong_AsVoidPtr(PyList_GET_ITEM(pointers, i)));
+        }
+        Py_DECREF(pointers);
+        if (!PyErr_Occurred()) {
+            PyErr_SetString(PyExc_RuntimeError,
+                "an error occurred in isl_set_foreach_basic_set");
+        }
+        return NULL;
+    }
+    return pointers;
+}
+
+
+static PyMethodDef _islhelper_methods[] = {
+    {"isl_basic_set_constraints", islhelper_basic_set_constraints,
+        METH_VARARGS, NULL},
+    {"isl_set_basic_sets", islhelper_set_basic_sets, METH_VARARGS, NULL},
+    {NULL, NULL, 0, NULL}
+};
+
+static struct PyModuleDef _islhelpermodule = {
+    PyModuleDef_HEAD_INIT,
+    "_islhelper",
+    NULL,
+    0,
+    _islhelper_methods
+};
+
+PyMODINIT_FUNC PyInit__islhelper(void) {
+    PyObject *m;
+    PyObject *dim_set;
+
+    m = PyModule_Create(&_islhelpermodule);
+    if (m == NULL) {
+        return NULL;
+    }
+    dim_set = PyLong_FromLong(isl_dim_set);
+    if (dim_set == NULL) {
+        Py_DECREF(m);
+        return NULL;
+    }
+    /* PyModule_AddObject() only steals the reference when it succeeds. */
+    if (PyModule_AddObject(m, "dim_set", dim_set) == -1) {
+        Py_DECREF(dim_set);
+        Py_DECREF(m);
+        return NULL;
+    }
+    return m;
+}
diff --git a/pypol/domains.py b/pypol/domains.py
new file mode 100644 (file)
index 0000000..fd588b7
--- /dev/null
@@ -0,0 +1,331 @@
+import functools
+
+from . import islhelper
+
+from .islhelper import mainctx, libisl, isl_set_basic_sets
+
+
+__all__ = [
+    'Domain',
+    'And', 'Or', 'Not',
+]
+
+
+@functools.total_ordering
+class Domain:
+    """
+    A union of polyhedra expressed over a common, sorted tuple of symbols.
+
+    Operations return Empty when the result has no polyhedron and the
+    Polyhedron itself when it has exactly one; only results made of several
+    polyhedra are built as Domain instances.
+    """
+
+    __slots__ = (
+        '_polyhedra',
+        '_symbols',
+        '_dimension',
+    )
+
+    def __new__(cls, *polyhedra):
+        """
+        Build a domain from a string or from Polyhedron instances.
+
+        A single string is parsed by fromstring() and a single Polyhedron is
+        returned unchanged.  Without arguments the result is Empty, as for
+        Or().  Several polyhedra are united through isl.
+
+        Raises TypeError for arguments of any other type.
+        """
+        from .polyhedra import Polyhedron
+        if len(polyhedra) == 0:
+            from .polyhedra import Empty
+            return Empty
+        if len(polyhedra) == 1:
+            polyhedron = polyhedra[0]
+            if isinstance(polyhedron, str):
+                return cls.fromstring(polyhedron)
+            elif isinstance(polyhedron, Polyhedron):
+                return polyhedron
+            else:
+                raise TypeError('argument must be a string '
+                    'or a Polyhedron instance')
+        for polyhedron in polyhedra:
+            if not isinstance(polyhedron, Polyhedron):
+                raise TypeError('arguments must be Polyhedron instances')
+        symbols = cls._xsymbols(polyhedra)
+        # Requires _toislset to be a classmethod: it is called on cls here.
+        islset = cls._toislset(polyhedra, symbols)
+        return cls._fromislset(islset, symbols)
+
+    @classmethod
+    def _xsymbols(cls, iterator):
+        """
+        Return the ordered tuple of symbols present in iterator.
+        """
+        symbols = set()
+        for item in iterator:
+            symbols.update(item.symbols)
+        return tuple(sorted(symbols))
+
+    @property
+    def polyhedra(self):
+        """Tuple of the polyhedra whose union is the domain."""
+        return self._polyhedra
+
+    @property
+    def symbols(self):
+        """Sorted tuple of the symbols the domain involves."""
+        return self._symbols
+
+    @property
+    def dimension(self):
+        """Number of symbols of the domain."""
+        return self._dimension
+
+    def disjoint(self):
+        """Return an equivalent domain made of pairwise disjoint polyhedra."""
+        islset = self._toislset(self.polyhedra, self.symbols)
+        # isl_set_make_disjoint() only takes the set, not a context.
+        islset = libisl.isl_set_make_disjoint(islset)
+        return self._fromislset(islset, self.symbols)
+
+    def isempty(self):
+        """Return True if the domain contains no point."""
+        islset = self._toislset(self.polyhedra, self.symbols)
+        empty = bool(libisl.isl_set_is_empty(islset))
+        libisl.isl_set_free(islset)
+        return empty
+
+    def __bool__(self):
+        return not self.isempty()
+
+    def isuniverse(self):
+        """
+        Return True if isl_set_plain_is_universe() reports a universe.
+
+        NOTE(review): isl's "plain" predicates are quick checks; confirm in
+        the isl manual whether a False result can be inexact.
+        """
+        islset = self._toislset(self.polyhedra, self.symbols)
+        universe = bool(libisl.isl_set_plain_is_universe(islset))
+        libisl.isl_set_free(islset)
+        return universe
+
+    def __eq__(self, other):
+        """Return True if both domains contain exactly the same points."""
+        symbols = self._xsymbols([self, other])
+        islset1 = self._toislset(self.polyhedra, symbols)
+        islset2 = other._toislset(other.polyhedra, symbols)
+        equal = bool(libisl.isl_set_is_equal(islset1, islset2))
+        libisl.isl_set_free(islset1)
+        libisl.isl_set_free(islset2)
+        return equal
+
+    def isdisjoint(self, other):
+        """Return True if the two domains have no point in common."""
+        symbols = self._xsymbols([self, other])
+        islset1 = self._toislset(self.polyhedra, symbols)
+        islset2 = other._toislset(other.polyhedra, symbols)
+        disjoint = bool(libisl.isl_set_is_disjoint(islset1, islset2))
+        libisl.isl_set_free(islset1)
+        libisl.isl_set_free(islset2)
+        return disjoint
+
+    def issubset(self, other):
+        """Return True if every point of the domain belongs to other."""
+        symbols = self._xsymbols([self, other])
+        islset1 = self._toislset(self.polyhedra, symbols)
+        islset2 = other._toislset(other.polyhedra, symbols)
+        subset = bool(libisl.isl_set_is_subset(islset1, islset2))
+        libisl.isl_set_free(islset1)
+        libisl.isl_set_free(islset2)
+        return subset
+
+    def __le__(self, other):
+        return self.issubset(other)
+
+    def __lt__(self, other):
+        """Return True if the domain is a strict subset of other."""
+        symbols = self._xsymbols([self, other])
+        islset1 = self._toislset(self.polyhedra, symbols)
+        islset2 = other._toislset(other.polyhedra, symbols)
+        subset = bool(libisl.isl_set_is_strict_subset(islset1, islset2))
+        libisl.isl_set_free(islset1)
+        libisl.isl_set_free(islset2)
+        return subset
+
+    def complement(self):
+        """Return the domain made of the points that are not in this one."""
+        islset = self._toislset(self.polyhedra, self.symbols)
+        islset = libisl.isl_set_complement(islset)
+        return self._fromislset(islset, self.symbols)
+
+    def __invert__(self):
+        return self.complement()
+
+    def simplify(self):
+        # see isl_set_coalesce, isl_set_detect_equalities,
+        # isl_set_remove_redundancies
+        # which ones? in which order?
+        raise NotImplementedError
+
+    def polyhedral_hull(self):
+        """Return the Polyhedron computed by isl_set_polyhedral_hull()."""
+        # several types of hull are available
+        # polyhedral seems to be the more appropriate, to be checked
+        from .polyhedra import Polyhedron
+        islset = self._toislset(self.polyhedra, self.symbols)
+        islbset = libisl.isl_set_polyhedral_hull(islset)
+        return Polyhedron._fromislbasicset(islbset, self.symbols)
+
+    def project(self, symbols):
+        # not sure what isl_set_project_out actually does…
+        # use isl_set_drop_constraints_involving_dims instead?
+        raise NotImplementedError
+
+    def sample(self):
+        """Return the Polyhedron computed by isl_set_sample()."""
+        from .polyhedra import Polyhedron
+        islset = self._toislset(self.polyhedra, self.symbols)
+        islbset = libisl.isl_set_sample(islset)
+        return Polyhedron._fromislbasicset(islbset, self.symbols)
+
+    def intersection(self, *others):
+        """Return the points common to the domain and to all of others."""
+        if len(others) == 0:
+            return self
+        symbols = self._xsymbols((self,) + others)
+        islset1 = self._toislset(self.polyhedra, symbols)
+        for other in others:
+            islset2 = other._toislset(other.polyhedra, symbols)
+            islset1 = libisl.isl_set_intersect(islset1, islset2)
+        return self._fromislset(islset1, symbols)
+
+    def __and__(self, other):
+        return self.intersection(other)
+
+    def union(self, *others):
+        """Return the points that belong to the domain or to any of others."""
+        if len(others) == 0:
+            return self
+        symbols = self._xsymbols((self,) + others)
+        islset1 = self._toislset(self.polyhedra, symbols)
+        for other in others:
+            islset2 = other._toislset(other.polyhedra, symbols)
+            islset1 = libisl.isl_set_union(islset1, islset2)
+        return self._fromislset(islset1, symbols)
+
+    def __or__(self, other):
+        return self.union(other)
+
+    def __add__(self, other):
+        return self.union(other)
+
+    def difference(self, other):
+        """Return the points of the domain that are not in other."""
+        symbols = self._xsymbols([self, other])
+        islset1 = self._toislset(self.polyhedra, symbols)
+        islset2 = other._toislset(other.polyhedra, symbols)
+        islset = libisl.isl_set_subtract(islset1, islset2)
+        return self._fromislset(islset, symbols)
+
+    def __sub__(self, other):
+        return self.difference(other)
+
+    def lexmin(self):
+        """Return the result of isl_set_lexmin() on the domain."""
+        islset = self._toislset(self.polyhedra, self.symbols)
+        islset = libisl.isl_set_lexmin(islset)
+        return self._fromislset(islset, self.symbols)
+
+    def lexmax(self):
+        """Return the result of isl_set_lexmax() on the domain."""
+        islset = self._toislset(self.polyhedra, self.symbols)
+        islset = libisl.isl_set_lexmax(islset)
+        return self._fromislset(islset, self.symbols)
+
+    @classmethod
+    def _fromislset(cls, islset, symbols):
+        """
+        Convert isl_set pointer islset into Empty, a Polyhedron or a Domain.
+
+        islset is consumed: it is freed once its basic sets are extracted.
+        """
+        from .polyhedra import Polyhedron
+        islset = libisl.isl_set_remove_divs(islset)
+        islbsets = isl_set_basic_sets(islset)
+        libisl.isl_set_free(islset)
+        polyhedra = []
+        for islbset in islbsets:
+            polyhedron = Polyhedron._fromislbasicset(islbset, symbols)
+            polyhedra.append(polyhedron)
+        if len(polyhedra) == 0:
+            from .polyhedra import Empty
+            return Empty
+        elif len(polyhedra) == 1:
+            return polyhedra[0]
+        else:
+            # Always a plain Domain, even when called through a subclass:
+            # a result made of several polyhedra is never a Polyhedron.
+            self = object.__new__(Domain)
+            self._polyhedra = tuple(polyhedra)
+            self._symbols = cls._xsymbols(polyhedra)
+            self._dimension = len(self._symbols)
+            return self
+
+    @classmethod
+    def _toislset(cls, polyhedra, symbols):
+        """
+        Return a new isl_set pointer for the union of polyhedra.
+
+        symbols is passed to every polyhedron's _toislbasicset().  polyhedra
+        must not be empty.
+        """
+        polyhedron = polyhedra[0]
+        islbset = polyhedron._toislbasicset(polyhedron.equalities,
+            polyhedron.inequalities, symbols)
+        islset1 = libisl.isl_set_from_basic_set(islbset)
+        for polyhedron in polyhedra[1:]:
+            islbset = polyhedron._toislbasicset(polyhedron.equalities,
+                polyhedron.inequalities, symbols)
+            islset2 = libisl.isl_set_from_basic_set(islbset)
+            islset1 = libisl.isl_set_union(islset1, islset2)
+        return islset1
+
+    @classmethod
+    def fromstring(cls, string):
+        raise NotImplementedError
+
+    def __repr__(self):
+        assert len(self.polyhedra) >= 2
+        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
+        return 'Or({})'.format(', '.join(strings))
+
+    @classmethod
+    def fromsympy(cls, expr):
+        raise NotImplementedError
+
+    def tosympy(self):
+        raise NotImplementedError
+
+
+def And(*domains):
+    """Return the intersection of domains, or Universe if none is given."""
+    if len(domains) == 0:
+        from .polyhedra import Universe
+        return Universe
+    else:
+        return domains[0].intersection(*domains[1:])
+
+def Or(*domains):
+    """Return the union of domains, or Empty if none is given."""
+    if len(domains) == 0:
+        from .polyhedra import Empty
+        return Empty
+    else:
+        return domains[0].union(*domains[1:])
+
+def Not(domain):
+    """Return the complement of domain (~domain)."""
+    return ~domain
diff --git a/pypol/isl.py b/pypol/isl.py
deleted file mode 100644 (file)
index 32ce305..0000000
+++ /dev/null
@@ -1,136 +0,0 @@
-import ctypes, ctypes.util
-
-from . import _isl
-
-
-__all__ = [
-    'Context',
-    'BasicSet',
-]
-
-
-libisl = ctypes.CDLL(ctypes.util.find_library('isl'))
-
-libisl.isl_printer_get_str.restype = ctypes.c_char_p
-libisl.isl_dim_set = _isl.dim_set
-
-
-class IslObject:
-
-    __slots__ = (
-        '_ptr',
-    )
-
-    def __init__(self, ptr):
-        self._ptr = ptr
-
-    @property
-    def _as_parameter_(self):
-        return self._ptr
-
-
-class Context(IslObject):
-
-    def __init__(self):
-        ptr = libisl.isl_ctx_alloc()
-        super().__init__(ptr)
-
-    #comment out so does not delete itself after being created
-    #def __del__(self):
-    #   libisl.isl_ctx_free(self)
-
-    def __eq__(self, other):
-        if not isinstance(other, Context):
-            return False
-        return self._ptr == other._ptr
-
-
-class BasicSet(IslObject):
-
-    def __str__(self):
-        ls = libisl.isl_basic_set_get_local_space(self)
-        ctx = libisl.isl_local_space_get_ctx(ls)
-        p = libisl.isl_printer_to_str(ctx)
-        p = libisl.isl_printer_print_basic_set(p, self)
-        string = libisl.isl_printer_get_str(p).decode()
-        return string
-
-    def __del__(self):
-        libisl.isl_basic_set_free(self)
-
-    def constraints(self):
-        return _isl.basic_set_constraints(self._ptr)
-        
-    def _fromisl(self, cls, symbols):
-        constraints = self.constraints()
-        equalities = []
-        inequalities = []
-        co = []
-        eq_string = ""
-        in_string = ""
-        string = ""
-        for constraint in constraints:
-            ls = libisl.isl_basic_set_get_local_space(self)
-            ctx = libisl.isl_local_space_get_ctx(ls)
-            p = libisl.isl_printer_to_str(ctx)
-            if libisl.isl_constraint_is_equality(constraint): #check if equality
-                constant = libisl.isl_constraint_get_constant_val(constraint)
-                const = libisl.isl_printer_print_val(p, constant)
-                const = libisl.isl_printer_get_str(const).decode()
-                const = int(const)
-                libisl.isl_printer_free(p)
-                for symbol in symbols:
-                    p = libisl.isl_printer_to_str(ctx)
-                    dim = symbols.index(symbol)
-                    coefficient = libisl.isl_constraint_get_coefficient_val(constraint, libisl.isl_dim_set, dim)
-                    coeff = libisl.isl_printer_print_val(p, coefficient)
-                    coeff = libisl.isl_printer_get_str(coeff).decode()
-                    coeff = int(coeff)
-                    if coeff!=0:
-                        co.append('{}{}'.format(coeff, symbols[dim])) 
-                        for value in co:
-                            string += '{}+'.format(value)
-                        equalities.append('{}{}==0'.format(string, const))
-                        co = []
-                        string = ''                     
-                    libisl.isl_printer_free(p)
-            else: #same for inequality
-                constant = libisl.isl_constraint_get_constant_val(constraint)
-                const = libisl.isl_printer_print_val(p, constant)
-                const = libisl.isl_printer_get_str(const).decode()
-                const = int(const)
-                libisl.isl_printer_free(p)
-                for symbol in symbols:
-                    p = libisl.isl_printer_to_str(ctx)
-                    dim = symbols.index(symbol)
-                    coefficient = libisl.isl_constraint_get_coefficient_val(constraint, libisl.isl_dim_set, dim)
-                    coeff = libisl.isl_printer_print_val(p, coefficient)
-                    coeff = libisl.isl_printer_get_str(coeff).decode()
-                    coeff = int(coeff)
-                    if coeff!=0:
-                        co.append('{}{}'.format(coeff, symbols[dim])) 
-                        for value in co:
-                            string += '{} + '.format(value)
-                        inequalities.append('{}{} <= 0'.format(string, const))
-                        co = []
-                        string = ""
-                    libisl.isl_printer_free(p)
-                    
-        for equations in equalities:
-            eq_string += ' {}'.format(equations)
-            eq_strings = eq_string.split()
-        print(eq_strings)
-        
-        for equations in inequalities:
-            in_string += ', {}'.format(equations)
-        print(in_string)
-        if eq_string and in_string:
-            final = '{}, {}'.format(eq_string, in_string)
-        elif eq_string != '':
-            final = '{}'.format(eq_strings)
-        elif in_string != '' :
-            final = '{}'.format(in_string)
-            
-                               
-        return ('{}({!r})'.format(cls.__name__,final))  
-        
diff --git a/pypol/islhelper.py b/pypol/islhelper.py
new file mode 100644 (file)
index 0000000..75d90d0
--- /dev/null
@@ -0,0 +1,56 @@
+import ctypes, ctypes.util
+
+from . import _islhelper
+from ._islhelper import isl_basic_set_constraints, isl_set_basic_sets
+
+
+__all__ = [
+    'libisl',
+    'mainctx',
+    'isl_val_to_int',
+    'isl_basic_set_to_str', 'isl_basic_set_constraints',
+    'isl_set_to_str', 'isl_set_basic_sets',
+]
+
+
+libisl = ctypes.CDLL(ctypes.util.find_library('isl'))
+
+libisl.isl_printer_get_str.restype = ctypes.c_char_p
+libisl.isl_dim_set = _islhelper.dim_set
+
+
+# NOTE(review): ctypes assumes a C int return type unless restype is set, so
+# pointers returned by libisl (this context included) may be truncated on
+# 64-bit platforms -- consider declaring restype/argtypes for them.
+mainctx = libisl.isl_ctx_alloc()
+
+
+def _isl_print(print_func, islobj):
+    """
+    Print islobj into a string printer with print_func and return the text.
+
+    The printer is freed before returning; islobj itself is not freed.
+    """
+    islpr = libisl.isl_printer_to_str(mainctx)
+    islpr = print_func(islpr, islobj)
+    try:
+        return libisl.isl_printer_get_str(islpr).decode()
+    finally:
+        libisl.isl_printer_free(islpr)
+
+def isl_val_to_int(islval):
+    """
+    Return the value of the isl_val pointer islval as a Python int.
+
+    islval is not freed.  int() raises ValueError if isl prints a value that
+    is not an integer.
+    """
+    return int(_isl_print(libisl.isl_printer_print_val, islval))
+
+def isl_basic_set_to_str(islbset):
+    """Return the isl text representation of the basic set islbset."""
+    return _isl_print(libisl.isl_printer_print_basic_set, islbset)
+
+def isl_set_to_str(islset):
+    """Return the isl text representation of the set islset."""
+    return _isl_print(libisl.isl_printer_print_set, islset)
diff --git a/pypol/linear.py b/pypol/linear.py
deleted file mode 100644 (file)
index b40415f..0000000
+++ /dev/null
@@ -1,784 +0,0 @@
-import ast
-import functools
-import numbers
-import re
-
-from fractions import Fraction, gcd
-
-from . import isl
-from .isl import libisl
-
-
-__all__ = [
-    'Expression', 'Constant', 'Symbol', 'symbols',
-    'eq', 'le', 'lt', 'ge', 'gt',
-    'Polyhedron',
-    'Empty', 'Universe'
-]
-
-
-def _polymorphic_method(func):
-    @functools.wraps(func)
-    def wrapper(a, b):
-        if isinstance(b, Expression):
-            return func(a, b)
-        if isinstance(b, numbers.Rational):
-            b = Constant(b)
-            return func(a, b)
-        return NotImplemented
-    return wrapper
-
-def _polymorphic_operator(func):
-    # A polymorphic operator should call a polymorphic method, hence we just
-    # have to test the left operand.
-    @functools.wraps(func)
-    def wrapper(a, b):
-        if isinstance(a, numbers.Rational):
-            a = Constant(a)
-            return func(a, b)
-        elif isinstance(a, Expression):
-            return func(a, b)
-        raise TypeError('arguments must be linear expressions')
-    return wrapper
-
-
-_main_ctx = isl.Context()
-
-
-class Expression:
-    """
-    This class implements linear expressions.
-    """
-
-    __slots__ = (
-        '_coefficients',
-        '_constant',
-        '_symbols',
-        '_dimension',
-    )
-
-    def __new__(cls, coefficients=None, constant=0):
-        if isinstance(coefficients, str):
-            if constant:
-                raise TypeError('too many arguments')
-            return cls.fromstring(coefficients)
-        if isinstance(coefficients, dict):
-            coefficients = coefficients.items()
-        if coefficients is None:
-            return Constant(constant)
-        coefficients = [(symbol, coefficient)
-                for symbol, coefficient in coefficients if coefficient != 0]
-        if len(coefficients) == 0:
-            return Constant(constant)
-        elif len(coefficients) == 1 and constant == 0:
-            symbol, coefficient = coefficients[0]
-            if coefficient == 1:
-                return Symbol(symbol)
-        self = object().__new__(cls)
-        self._coefficients = {}
-        for symbol, coefficient in coefficients:
-            if isinstance(symbol, Symbol):
-                symbol = symbol.name
-            elif not isinstance(symbol, str):
-                raise TypeError('symbols must be strings or Symbol instances')
-            if isinstance(coefficient, Constant):
-                coefficient = coefficient.constant
-            if not isinstance(coefficient, numbers.Rational):
-                raise TypeError('coefficients must be rational numbers or Constant instances')
-            self._coefficients[symbol] = coefficient
-        if isinstance(constant, Constant):
-            constant = constant.constant
-        if not isinstance(constant, numbers.Rational):
-            raise TypeError('constant must be a rational number or a Constant instance')
-        self._constant = constant
-        self._symbols = tuple(sorted(self._coefficients))
-        self._dimension = len(self._symbols)
-        return self
-
-    @classmethod
-    def _fromast(cls, node):
-        if isinstance(node, ast.Module) and len(node.body) == 1:
-            return cls._fromast(node.body[0])
-        elif isinstance(node, ast.Expr):
-            return cls._fromast(node.value)
-        elif isinstance(node, ast.Name):
-            return Symbol(node.id)
-        elif isinstance(node, ast.Num):
-            return Constant(node.n)
-        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
-            return -cls._fromast(node.operand)
-        elif isinstance(node, ast.BinOp):
-            left = cls._fromast(node.left)
-            right = cls._fromast(node.right)
-            if isinstance(node.op, ast.Add):
-                return left + right
-            elif isinstance(node.op, ast.Sub):
-                return left - right
-            elif isinstance(node.op, ast.Mult):
-                return left * right
-            elif isinstance(node.op, ast.Div):
-                return left / right
-        raise SyntaxError('invalid syntax')
-
-    @classmethod
-    def fromstring(cls, string):
-        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
-        tree = ast.parse(string, 'eval')
-        return cls._fromast(tree)
-
-    @property
-    def symbols(self):
-        return self._symbols
-
-    @property
-    def dimension(self):
-        return self._dimension
-
-    def coefficient(self, symbol):
-        if isinstance(symbol, Symbol):
-            symbol = str(symbol)
-        elif not isinstance(symbol, str):
-            raise TypeError('symbol must be a string or a Symbol instance')
-        try:
-            return self._coefficients[symbol]
-        except KeyError:
-            return 0
-
-    __getitem__ = coefficient
-
-    def coefficients(self):
-        for symbol in self.symbols:
-            yield symbol, self.coefficient(symbol)
-
-    @property
-    def constant(self):
-        return self._constant
-
-    def isconstant(self):
-        return False
-
-    def values(self):
-        for symbol in self.symbols:
-            yield self.coefficient(symbol)
-        yield self.constant
-
-    def issymbol(self):
-        return False
-
-    def __bool__(self):
-        return True
-
-    def __pos__(self):
-        return self
-
-    def __neg__(self):
-        return self * -1
-
-    @_polymorphic_method
-    def __add__(self, other):
-        coefficients = dict(self.coefficients())
-        for symbol, coefficient in other.coefficients():
-            if symbol in coefficients:
-                coefficients[symbol] += coefficient
-            else:
-                coefficients[symbol] = coefficient
-        constant = self.constant + other.constant
-        return Expression(coefficients, constant)
-
-    __radd__ = __add__
-
-    @_polymorphic_method
-    def __sub__(self, other):
-        coefficients = dict(self.coefficients())
-        for symbol, coefficient in other.coefficients():
-            if symbol in coefficients:
-                coefficients[symbol] -= coefficient
-            else:
-                coefficients[symbol] = -coefficient
-        constant = self.constant - other.constant
-        return Expression(coefficients, constant)
-
-    def __rsub__(self, other):
-        return -(self - other)
-
-    @_polymorphic_method
-    def __mul__(self, other):
-        if other.isconstant():
-            coefficients = dict(self.coefficients())
-            for symbol in coefficients:
-                coefficients[symbol] *= other.constant
-            constant = self.constant * other.constant
-            return Expression(coefficients, constant)
-        if isinstance(other, Expression) and not self.isconstant():
-            raise ValueError('non-linear expression: '
-                    '{} * {}'.format(self._parenstr(), other._parenstr()))
-        return NotImplemented
-
-    __rmul__ = __mul__
-
-    @_polymorphic_method
-    def __truediv__(self, other):
-        if other.isconstant():
-            coefficients = dict(self.coefficients())
-            for symbol in coefficients:
-                coefficients[symbol] = \
-                        Fraction(coefficients[symbol], other.constant)
-            constant = Fraction(self.constant, other.constant)
-            return Expression(coefficients, constant)
-        if isinstance(other, Expression):
-            raise ValueError('non-linear expression: '
-                '{} / {}'.format(self._parenstr(), other._parenstr()))
-        return NotImplemented
-
-    def __rtruediv__(self, other):
-        if isinstance(other, self):
-            if self.isconstant():
-                constant = Fraction(other, self.constant)
-                return Expression(constant=constant)
-            else:
-                raise ValueError('non-linear expression: '
-                        '{} / {}'.format(other._parenstr(), self._parenstr()))
-        return NotImplemented
-
-    def __str__(self):
-        string = ''
-        i = 0
-        for symbol in self.symbols:
-            coefficient = self.coefficient(symbol)
-            if coefficient == 1:
-                if i == 0:
-                    string += symbol
-                else:
-                    string += ' + {}'.format(symbol)
-            elif coefficient == -1:
-                if i == 0:
-                    string += '-{}'.format(symbol)
-                else:
-                    string += ' - {}'.format(symbol)
-            else:
-                if i == 0:
-                    string += '{}*{}'.format(coefficient, symbol)
-                elif coefficient > 0:
-                    string += ' + {}*{}'.format(coefficient, symbol)
-                else:
-                    assert coefficient < 0
-                    coefficient *= -1
-                    string += ' - {}*{}'.format(coefficient, symbol)
-            i += 1
-        constant = self.constant
-        if constant != 0 and i == 0:
-            string += '{}'.format(constant)
-        elif constant > 0:
-            string += ' + {}'.format(constant)
-        elif constant < 0:
-            constant *= -1
-            string += ' - {}'.format(constant)
-        if string == '':
-            string = '0'
-        return string
-
-    def _parenstr(self, always=False):
-        string = str(self)
-        if not always and (self.isconstant() or self.issymbol()):
-            return string
-        else:
-            return '({})'.format(string)
-
-    def __repr__(self):
-        return '{}({!r})'.format(self.__class__.__name__, str(self))
-
-    @_polymorphic_method
-    def __eq__(self, other):
-        # "normal" equality
-        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
-        return isinstance(other, Expression) and \
-                self._coefficients == other._coefficients and \
-                self.constant == other.constant
-
-    def __hash__(self):
-        return hash((tuple(sorted(self._coefficients.items())), self._constant))
-
-    def _toint(self):
-        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
-                [value.denominator for value in self.values()])
-        return self * lcm
-
-    @_polymorphic_method
-    def _eq(self, other):
-        return Polyhedron(equalities=[(self - other)._toint()])
-
-    @_polymorphic_method
-    def __le__(self, other):
-        return Polyhedron(inequalities=[(other - self)._toint()])
-
-    @_polymorphic_method
-    def __lt__(self, other):
-        return Polyhedron(inequalities=[(other - self)._toint() - 1])
-
-    @_polymorphic_method
-    def __ge__(self, other):
-        return Polyhedron(inequalities=[(self - other)._toint()])
-
-    @_polymorphic_method
-    def __gt__(self, other):
-        return Polyhedron(inequalities=[(self - other)._toint() - 1])
-
-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        coefficients = {}
-        constant = 0
-        for symbol, coefficient in expr.as_coefficients_dict().items():
-            coefficient = Fraction(coefficient.p, coefficient.q)
-            if symbol == sympy.S.One:
-                constant = coefficient
-            elif isinstance(symbol, sympy.Symbol):
-                symbol = symbol.name
-                coefficients[symbol] = coefficient
-            else:
-                raise ValueError('non-linear expression: {!r}'.format(expr))
-        return cls(coefficients, constant)
-
-    def tosympy(self):
-        import sympy
-        expr = 0
-        for symbol, coefficient in self.coefficients():
-            term = coefficient * sympy.Symbol(symbol)
-            expr += term
-        expr += self.constant
-        return expr
-
-
-class Constant(Expression):
-
-    def __new__(cls, numerator=0, denominator=None):
-        self = object().__new__(cls)
-        if denominator is None:
-            if isinstance(numerator, numbers.Rational):
-                self._constant = numerator
-            elif isinstance(numerator, Constant):
-                self._constant = numerator.constant
-            else:
-                raise TypeError('constant must be a rational number or a Constant instance')
-        else:
-            self._constant = Fraction(numerator, denominator)
-        self._coefficients = {}
-        self._symbols = ()
-        self._dimension = 0
-        return self
-
-    def isconstant(self):
-        return True
-
-    def __bool__(self):
-        return bool(self.constant)
-
-    def __repr__(self):
-        if self.constant.denominator == 1:
-            return '{}({!r})'.format(self.__class__.__name__, self.constant)
-        else:
-            return '{}({!r}, {!r})'.format(self.__class__.__name__,
-                self.constant.numerator, self.constant.denominator)
-
-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        if isinstance(expr, sympy.Rational):
-            return cls(expr.p, expr.q)
-        elif isinstance(expr, numbers.Rational):
-            return cls(expr)
-        else:
-            raise TypeError('expr must be a sympy.Rational instance')
-
-
-class Symbol(Expression):
-
-    __slots__ = Expression.__slots__ + (
-        '_name',
-    )
-
-    def __new__(cls, name):
-        if isinstance(name, Symbol):
-            name = name.name
-        elif not isinstance(name, str):
-            raise TypeError('name must be a string or a Symbol instance')
-        self = object().__new__(cls)
-        self._coefficients = {name: 1}
-        self._constant = 0
-        self._symbols = tuple(name)
-        self._name = name
-        self._dimension = 1
-        return self
-
-    @property
-    def name(self):
-        return self._name
-
-    def issymbol(self):
-        return True
-
-    def __repr__(self):
-        return '{}({!r})'.format(self.__class__.__name__, self._name)
-
-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        if isinstance(expr, sympy.Symbol):
-            return cls(expr.name)
-        else:
-            raise TypeError('expr must be a sympy.Symbol instance')
-
-
-def symbols(names):
-    if isinstance(names, str):
-        names = names.replace(',', ' ').split()
-    return (Symbol(name) for name in names)
-
-
-@_polymorphic_operator
-def eq(a, b):
-    return a.__eq__(b)
-
-@_polymorphic_operator
-def le(a, b):
-    return a.__le__(b)
-
-@_polymorphic_operator
-def lt(a, b):
-    return a.__lt__(b)
-
-@_polymorphic_operator
-def ge(a, b):
-    return a.__ge__(b)
-
-@_polymorphic_operator
-def gt(a, b):
-    return a.__gt__(b)
-
-
-class Polyhedron:
-    """
-    This class implements polyhedrons.
-    """
-
-    __slots__ = (
-        '_equalities',
-        '_inequalities',
-        '_constraints',
-        '_symbols',
-    )
-
-    def __new__(cls, equalities=None, inequalities=None):
-        if isinstance(equalities, str):
-            if inequalities is not None:
-                raise TypeError('too many arguments')
-            return cls.fromstring(equalities)
-        self = super().__new__(cls)
-        self._equalities = []
-        if equalities is not None:
-            for constraint in equalities:
-                for value in constraint.values():
-                    if value.denominator != 1:
-                        raise TypeError('non-integer constraint: '
-                                '{} == 0'.format(constraint))
-                self._equalities.append(constraint)
-        self._equalities = tuple(self._equalities)
-        self._inequalities = []
-        if inequalities is not None:
-            for constraint in inequalities:
-                for value in constraint.values():
-                    if value.denominator != 1:
-                        raise TypeError('non-integer constraint: '
-                                '{} <= 0'.format(constraint))
-                self._inequalities.append(constraint)
-        self._inequalities = tuple(self._inequalities)
-        self._constraints = self._equalities + self._inequalities
-        self._symbols = set()
-        for constraint in self._constraints:
-            self.symbols.update(constraint.symbols)
-        self._symbols = tuple(sorted(self._symbols))
-        return self
-
-    @classmethod
-    def _fromast(cls, node):
-        if isinstance(node, ast.Module) and len(node.body) == 1:
-            return cls._fromast(node.body[0])
-        elif isinstance(node, ast.Expr):
-            return cls._fromast(node.value)
-        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd):
-            equalities1, inequalities1 = cls._fromast(node.left)
-            equalities2, inequalities2 = cls._fromast(node.right)
-            equalities = equalities1 + equalities2
-            inequalities = inequalities1 + inequalities2
-            return equalities, inequalities
-        elif isinstance(node, ast.Compare):
-            equalities = []
-            inequalities = []
-            left = Expression._fromast(node.left)
-            for i in range(len(node.ops)):
-                op = node.ops[i]
-                right = Expression._fromast(node.comparators[i])
-                if isinstance(op, ast.Lt):
-                    inequalities.append(right - left - 1)
-                elif isinstance(op, ast.LtE):
-                    inequalities.append(right - left)
-                elif isinstance(op, ast.Eq):
-                    equalities.append(left - right)
-                elif isinstance(op, ast.GtE):
-                    inequalities.append(left - right)
-                elif isinstance(op, ast.Gt):
-                    inequalities.append(left - right - 1)
-                else:
-                    break
-                left = right
-            else:
-                return equalities, inequalities
-        raise SyntaxError('invalid syntax')
-
-    @classmethod
-    def fromstring(cls, string):
-        string = string.strip()
-        string = re.sub(r'^\{\s*|\s*\}$', '', string)
-        string = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', string)
-        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
-        tokens = re.split(r',|;|and|&&|/\\|∧', string, flags=re.I)
-        tokens = ['({})'.format(token) for token in tokens]
-        string = ' & '.join(tokens)
-        tree = ast.parse(string, 'eval')
-        equalities, inequalities = cls._fromast(tree)
-        return cls(equalities, inequalities)
-
-    @property
-    def equalities(self):
-        return self._equalities
-
-    @property
-    def inequalities(self):
-        return self._inequalities
-
-    @property
-    def constraints(self):
-        return self._constraints
-
-    @property
-    def symbols(self):
-        return self._symbols
-
-    @property
-    def dimension(self):
-        return len(self.symbols)
-
-    def __bool__(self):
-        return not self.is_empty()
-
-    def __contains__(self, value):
-        # is the value in the polyhedron?
-        raise NotImplementedError
-
-    def __eq__(self, other):
-        # works correctly when symbols is not passed
-        # should be equal if values are the same even if symbols are different
-        bset = self._toisl()
-        other = other._toisl()
-        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))
-
-    def isempty(self):
-        bset = self._toisl()
-        return bool(libisl.isl_basic_set_is_empty(bset))
-
-    def isuniverse(self):
-        bset = self._toisl()
-        return bool(libisl.isl_basic_set_is_universe(bset))
-
-    def isdisjoint(self, other):
-        # return true if the polyhedron has no elements in common with other
-        #symbols = self._symbolunion(other)
-        bset = self._toisl()
-        other = other._toisl()
-        return bool(libisl.isl_set_is_disjoint(bset, other))
-
-    def issubset(self, other):
-        # check if self(bset) is a subset of other
-        symbols = self._symbolunion(other)
-        bset = self._toisl(symbols)
-        other = other._toisl(symbols)
-        return bool(libisl.isl_set_is_strict_subset(other, bset))
-
-    def __le__(self, other):
-        return self.issubset(other)
-
-    def __lt__(self, other):
-        symbols = self._symbolunion(other)
-        bset = self._toisl(symbols)
-        other = other._toisl(symbols)
-        return bool(libisl.isl_set_is_strict_subset(other, bset))
-
-    def issuperset(self, other):
-        # test whether every element in other is in the polyhedron
-        raise NotImplementedError
-
-    def __ge__(self, other):
-        return self.issuperset(other)
-
-    def __gt__(self, other):
-        symbols = self._symbolunion(other)
-        bset = self._toisl(symbols)
-        other = other._toisl(symbols)
-        bool(libisl.isl_set_is_strict_subset(other, bset))
-        raise NotImplementedError
-
-    def union(self, *others):
-        # return a new polyhedron with elements from the polyhedron and all
-        # others (convex union)
-        raise NotImplementedError
-
-    def __or__(self, other):
-        return self.union(other)
-
-    def intersection(self, *others):
-        # return a new polyhedron with elements common to the polyhedron and all
-        # others
-        # a poor man's implementation could be:
-        # equalities = list(self.equalities)
-        # inequalities = list(self.inequalities)
-        # for other in others:
-        #     equalities.extend(other.equalities)
-        #     inequalities.extend(other.inequalities)
-        # return self.__class__(equalities, inequalities)
-        raise NotImplementedError
-
-    def __and__(self, other):
-        return self.intersection(other)
-
-    def difference(self, other):
-        # return a new polyhedron with elements in the polyhedron that are not in the other
-        symbols = self._symbolunion(other)
-        bset = self._toisl(symbols)
-        other = other._toisl(symbols)
-        difference = libisl.isl_set_subtract(bset, other)
-        return difference
-
-    def __sub__(self, other):
-        return self.difference(other)
-
-    def __str__(self):
-        constraints = []
-        for constraint in self.equalities:
-            constraints.append('{} == 0'.format(constraint))
-        for constraint in self.inequalities:
-            constraints.append('{} >= 0'.format(constraint))
-        return '{}'.format(', '.join(constraints))
-
-    def __repr__(self):
-        if self.isempty():
-            return 'Empty'
-        elif self.isuniverse():
-            return 'Universe'
-        else:
-            return '{}({!r})'.format(self.__class__.__name__, str(self))
-
-    @classmethod
-    def _fromsympy(cls, expr):
-        import sympy
-        equalities = []
-        inequalities = []
-        if expr.func == sympy.And:
-            for arg in expr.args:
-                arg_eqs, arg_ins = cls._fromsympy(arg)
-                equalities.extend(arg_eqs)
-                inequalities.extend(arg_ins)
-        elif expr.func == sympy.Eq:
-            expr = Expression.fromsympy(expr.args[0] - expr.args[1])
-            equalities.append(expr)
-        else:
-            if expr.func == sympy.Lt:
-                expr = Expression.fromsympy(expr.args[1] - expr.args[0] - 1)
-            elif expr.func == sympy.Le:
-                expr = Expression.fromsympy(expr.args[1] - expr.args[0])
-            elif expr.func == sympy.Ge:
-                expr = Expression.fromsympy(expr.args[0] - expr.args[1])
-            elif expr.func == sympy.Gt:
-                expr = Expression.fromsympy(expr.args[0] - expr.args[1] - 1)
-            else:
-                raise ValueError('non-polyhedral expression: {!r}'.format(expr))
-            inequalities.append(expr)
-        return equalities, inequalities
-
-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        equalities, inequalities = cls._fromsympy(expr)
-        return cls(equalities, inequalities)
-
-    def tosympy(self):
-        import sympy
-        constraints = []
-        for equality in self.equalities:
-            constraints.append(sympy.Eq(equality.tosympy(), 0))
-        for inequality in self.inequalities:
-            constraints.append(sympy.Ge(inequality.tosympy(), 0))
-        return sympy.And(*constraints)
-
-    def _symbolunion(self, *others):
-        symbols = set(self.symbols)
-        for other in others:
-            symbols.update(other.symbols)
-        return sorted(symbols)
-
-    def _toisl(self, symbols=None):
-        if symbols is None:
-            symbols = self.symbols
-        dimension = len(symbols)
-        space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension)
-        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
-        ls = libisl.isl_local_space_from_space(space)
-        for equality in self.equalities:
-            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
-            for symbol, coefficient in equality.coefficients():
-                val = str(coefficient).encode()
-                val = libisl.isl_val_read_from_str(_main_ctx, val)
-                dim = symbols.index(symbol)
-                ceq = libisl.isl_constraint_set_coefficient_val(ceq, libisl.isl_dim_set, dim, val)
-            if equality.constant != 0:
-                val = str(equality.constant).encode()
-                val = libisl.isl_val_read_from_str(_main_ctx, val)
-                ceq = libisl.isl_constraint_set_constant_val(ceq, val)
-            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
-        for inequality in self.inequalities:
-            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
-            for symbol, coefficient in inequality.coefficients():
-                val = str(coefficient).encode()
-                val = libisl.isl_val_read_from_str(_main_ctx, val)
-                dim = symbols.index(symbol)
-                cin = libisl.isl_constraint_set_coefficient_val(cin, libisl.isl_dim_set, dim, val)
-            if inequality.constant != 0:
-                val = str(inequality.constant).encode()
-                val = libisl.isl_val_read_from_str(_main_ctx, val)
-                cin = libisl.isl_constraint_set_constant_val(cin, val)
-            bset = libisl.isl_basic_set_add_constraint(bset, cin)
-        bset = isl.BasicSet(bset)
-        return bset
-
-    @classmethod
-    def _fromisl(cls, bset, symbols):
-        raise NotImplementedError
-        equalities = ...
-        inequalities = ...
-        return cls(equalities, inequalities)
-        '''takes basic set  in isl form and puts back into python version of polyhedron
-        isl example code gives isl form as:
-            "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
-            our printer is giving form as:
-            { [i0, i1] : 2i1 >= -2 - i0 } '''
-
-Empty = eq(0,1)
-
-Universe = Polyhedron()
-
-
-if __name__ == '__main__':
-    #p = Polyhedron('2a + 2b + 1 == 0') # empty
-    p = Polyhedron('3x + 2y + 3 == 0, y == 0') # not empty
-    ip = p._toisl()
-    print(ip)
-    print(ip.constraints())
diff --git a/pypol/linexprs.py b/pypol/linexprs.py
new file mode 100644 (file)
index 0000000..0db7edd
--- /dev/null
@@ -0,0 +1,431 @@
+import ast
+import functools
+import numbers
+import re
+
+from fractions import Fraction, gcd
+
+
+__all__ = [  # public names of this module, used by `from .linexprs import *`
+    'Expression',
+    'Symbol', 'symbols',
+    'Constant',
+]
+
+
+def _polymorphic(func):  # binary-operator decorator that normalizes the right operand
+    @functools.wraps(func)
+    def wrapper(left, right):
+        if not isinstance(right, Expression):
+            if not isinstance(right, numbers.Rational):
+                return NotImplemented
+            # promote plain rationals so func only ever sees Expressions
+            right = Constant(right)
+        return func(left, right)
+    return wrapper
+
+
+class Expression:
+    """
+    Linear expression: rational coefficients on named symbols plus a rational constant.
+    """
+
+    __slots__ = (
+        '_coefficients',
+        '_constant',
+        '_symbols',
+        '_dimension',
+    )
+
+    def __new__(cls, coefficients=None, constant=0):
+        if isinstance(coefficients, str):
+            if constant:
+                raise TypeError('too many arguments')
+            return cls.fromstring(coefficients)
+        if isinstance(coefficients, dict):
+            coefficients = coefficients.items()
+        if coefficients is None:
+            return Constant(constant)
+        coefficients = [(symbol, coefficient)
+            for symbol, coefficient in coefficients if coefficient != 0]
+        if len(coefficients) == 0:
+            return Constant(constant)
+        elif len(coefficients) == 1 and constant == 0:
+            symbol, coefficient = coefficients[0]
+            if coefficient == 1:  # a bare 1*symbol is represented as a Symbol
+                return Symbol(symbol)
+        self = object.__new__(cls)
+        self._coefficients = {}
+        for symbol, coefficient in coefficients:
+            if isinstance(symbol, Symbol):
+                symbol = symbol.name
+            elif not isinstance(symbol, str):
+                raise TypeError('symbols must be strings or Symbol instances')
+            if isinstance(coefficient, Constant):
+                coefficient = coefficient.constant
+            if not isinstance(coefficient, numbers.Rational):
+                raise TypeError('coefficients must be rational numbers '
+                    'or Constant instances')
+            self._coefficients[symbol] = coefficient
+        if isinstance(constant, Constant):
+            constant = constant.constant
+        if not isinstance(constant, numbers.Rational):
+            raise TypeError('constant must be a rational number '
+                'or a Constant instance')
+        self._constant = constant
+        self._symbols = tuple(sorted(self._coefficients))
+        self._dimension = len(self._symbols)
+        return self
+
+    def coefficient(self, symbol):
+        if isinstance(symbol, Symbol):
+            symbol = symbol.name  # use the raw name, as __new__ does
+        elif not isinstance(symbol, str):
+            raise TypeError('symbol must be a string or a Symbol instance')
+        try:
+            return self._coefficients[symbol]
+        except KeyError:
+            return 0
+
+    __getitem__ = coefficient
+
+    def coefficients(self):
+        for symbol in self.symbols:
+            yield symbol, self.coefficient(symbol)
+
+    @property
+    def constant(self):
+        return self._constant
+
+    @property
+    def symbols(self):
+        return self._symbols
+
+    @property
+    def dimension(self):
+        return self._dimension
+
+    def isconstant(self):
+        return False
+
+    def issymbol(self):
+        return False
+
+    def values(self):
+        for symbol in self.symbols:
+            yield self.coefficient(symbol)
+        yield self.constant
+
+    def __bool__(self):
+        return True
+
+    def __pos__(self):
+        return self
+
+    def __neg__(self):
+        return self * -1
+
+    @_polymorphic
+    def __add__(self, other):
+        coefficients = dict(self.coefficients())
+        for symbol, coefficient in other.coefficients():
+            if symbol in coefficients:
+                coefficients[symbol] += coefficient
+            else:
+                coefficients[symbol] = coefficient
+        constant = self.constant + other.constant
+        return Expression(coefficients, constant)
+
+    __radd__ = __add__
+
+    @_polymorphic
+    def __sub__(self, other):
+        coefficients = dict(self.coefficients())
+        for symbol, coefficient in other.coefficients():
+            if symbol in coefficients:
+                coefficients[symbol] -= coefficient
+            else:
+                coefficients[symbol] = -coefficient
+        constant = self.constant - other.constant
+        return Expression(coefficients, constant)
+
+    def __rsub__(self, other):
+        return -(self - other)
+
+    @_polymorphic
+    def __mul__(self, other):
+        if other.isconstant():
+            coefficients = dict(self.coefficients())
+            for symbol in coefficients:
+                coefficients[symbol] *= other.constant
+            constant = self.constant * other.constant
+            return Expression(coefficients, constant)
+        if isinstance(other, Expression) and not self.isconstant():
+            raise ValueError('non-linear expression: '
+                    '{} * {}'.format(self._parenstr(), other._parenstr()))
+        return NotImplemented  # self is constant: Python retries other.__rmul__(self)
+
+    __rmul__ = __mul__
+
+    @_polymorphic
+    def __truediv__(self, other):
+        if other.isconstant():
+            coefficients = dict(self.coefficients())
+            for symbol in coefficients:
+                coefficients[symbol] = \
+                        Fraction(coefficients[symbol], other.constant)
+            constant = Fraction(self.constant, other.constant)
+            return Expression(coefficients, constant)
+        if isinstance(other, Expression):
+            raise ValueError('non-linear expression: '
+                '{} / {}'.format(self._parenstr(), other._parenstr()))
+        return NotImplemented
+
+    @_polymorphic
+    def __rtruediv__(self, other):
+        # other (the left operand) has been coerced to a Constant
+        if self.isconstant():
+            constant = Fraction(other.constant, self.constant)
+            return Expression(constant=constant)
+        else:
+            raise ValueError('non-linear expression: '
+                    '{} / {}'.format(other._parenstr(), self._parenstr()))
+
+    @_polymorphic
+    def __eq__(self, other):
+        # "normal" equality
+        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
+        return isinstance(other, Expression) and \
+                self._coefficients == other._coefficients and \
+                self.constant == other.constant
+
+    @_polymorphic
+    def __le__(self, other):
+        from .polyhedra import Le
+        return Le(self, other)
+
+    @_polymorphic
+    def __lt__(self, other):
+        from .polyhedra import Lt
+        return Lt(self, other)
+
+    @_polymorphic
+    def __ge__(self, other):
+        from .polyhedra import Ge
+        return Ge(self, other)
+
+    @_polymorphic
+    def __gt__(self, other):
+        from .polyhedra import Gt
+        return Gt(self, other)
+
+    def __hash__(self):  # must agree with __eq__, under which Constant(3) == 3
+        return hash((tuple(sorted(self._coefficients.items())), self._constant) if self._coefficients else self._constant)
+
+    def _toint(self):  # scale by the lcm of all denominators -> integer values
+        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
+            [value.denominator for value in self.values()])
+        return self * lcm
+
+    @classmethod
+    def _fromast(cls, node):
+        if isinstance(node, ast.Module) and len(node.body) == 1:
+            return cls._fromast(node.body[0])
+        elif isinstance(node, ast.Expr):
+            return cls._fromast(node.value)
+        elif isinstance(node, ast.Name):
+            return Symbol(node.id)
+        elif isinstance(node, ast.Num):
+            return Constant(node.n)
+        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
+            return -cls._fromast(node.operand)
+        elif isinstance(node, ast.BinOp):
+            left = cls._fromast(node.left)
+            right = cls._fromast(node.right)
+            if isinstance(node.op, ast.Add):
+                return left + right
+            elif isinstance(node.op, ast.Sub):
+                return left - right
+            elif isinstance(node.op, ast.Mult):
+                return left * right
+            elif isinstance(node.op, ast.Div):
+                return left / right
+        raise SyntaxError('invalid syntax')
+
+    @classmethod
+    def fromstring(cls, string):
+        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
+        tree = ast.parse(string, 'eval')  # 'eval' is the *filename*; mode stays 'exec' (Module)
+        return cls._fromast(tree)
+
+    def __str__(self):
+        string = ''
+        i = 0
+        for symbol in self.symbols:
+            coefficient = self.coefficient(symbol)
+            if coefficient == 1:
+                if i == 0:
+                    string += symbol
+                else:
+                    string += ' + {}'.format(symbol)
+            elif coefficient == -1:
+                if i == 0:
+                    string += '-{}'.format(symbol)
+                else:
+                    string += ' - {}'.format(symbol)
+            else:
+                if i == 0:
+                    string += '{}*{}'.format(coefficient, symbol)
+                elif coefficient > 0:
+                    string += ' + {}*{}'.format(coefficient, symbol)
+                else:
+                    assert coefficient < 0
+                    coefficient *= -1
+                    string += ' - {}*{}'.format(coefficient, symbol)
+            i += 1
+        constant = self.constant
+        if constant != 0 and i == 0:
+            string += '{}'.format(constant)
+        elif constant > 0:
+            string += ' + {}'.format(constant)
+        elif constant < 0:
+            constant *= -1
+            string += ' - {}'.format(constant)
+        if string == '':
+            string = '0'
+        return string
+
+    def _parenstr(self, always=False):
+        string = str(self)
+        if not always and (self.isconstant() or self.issymbol()):
+            return string
+        else:
+            return '({})'.format(string)
+
+    def __repr__(self):
+        return '{}({!r})'.format(self.__class__.__name__, str(self))
+
+    @classmethod
+    def fromsympy(cls, expr):
+        import sympy
+        coefficients = {}
+        constant = 0
+        for symbol, coefficient in expr.as_coefficients_dict().items():
+            coefficient = Fraction(coefficient.p, coefficient.q)
+            if symbol == sympy.S.One:
+                constant = coefficient
+            elif isinstance(symbol, sympy.Symbol):
+                symbol = symbol.name
+                coefficients[symbol] = coefficient
+            else:
+                raise ValueError('non-linear expression: {!r}'.format(expr))
+        return cls(coefficients, constant)
+
+    def tosympy(self):
+        import sympy
+        expr = 0
+        for symbol, coefficient in self.coefficients():
+            term = coefficient * sympy.Symbol(symbol)
+            expr += term
+        expr += self.constant
+        return expr
+
+
+class Symbol(Expression):
+    """A single named variable: the linear expression 1*name + 0."""
+    __slots__ = (  # only the new slot; the base-class slots are inherited
+        '_name',
+    )
+
+    def __new__(cls, name):
+        if isinstance(name, Symbol):
+            name = name.name
+        elif not isinstance(name, str):
+            raise TypeError('name must be a string or a Symbol instance')
+        name = name.strip()
+        self = object.__new__(cls)
+        self._coefficients = {name: 1}
+        self._constant = 0
+        self._symbols = (name,)  # tuple(name) would split the name into characters
+        self._name = name
+        self._dimension = 1
+        return self
+
+    @property
+    def name(self):
+        return self._name
+
+    def issymbol(self):
+        return True
+
+    @classmethod
+    def _fromast(cls, node):
+        if isinstance(node, ast.Module) and len(node.body) == 1:
+            return cls._fromast(node.body[0])
+        elif isinstance(node, ast.Expr):
+            return cls._fromast(node.value)
+        elif isinstance(node, ast.Name):
+            return Symbol(node.id)
+        raise SyntaxError('invalid syntax')
+
+    def __repr__(self):
+        return '{}({!r})'.format(self.__class__.__name__, self._name)
+
+    @classmethod
+    def fromsympy(cls, expr):
+        import sympy
+        if isinstance(expr, sympy.Symbol):
+            return cls(expr.name)
+        else:
+            raise TypeError('expr must be a sympy.Symbol instance')
+
+
+def symbols(names):
+    # accept 'a, b c' (commas and/or whitespace) or any iterable of names
+    parts = names.replace(',', ' ').split() if isinstance(names, str) else names
+    return (Symbol(part) for part in parts)
+
+
+class Constant(Expression):
+    """A linear expression without symbols: a single rational constant."""
+
+    def __new__(cls, numerator=0, denominator=None):
+        self = object.__new__(cls)
+        # reuse another Constant's value; otherwise Fraction parses/validates
+        is_copy = denominator is None and isinstance(numerator, Constant)
+        self._constant = (numerator.constant if is_copy
+                          else Fraction(numerator, denominator))
+        self._coefficients, self._symbols = {}, ()
+        self._dimension = 0
+        return self
+
+    def isconstant(self):
+        return True
+
+    def __bool__(self):
+        # falsy exactly when the value is zero
+        return self.constant != 0
+
+    @classmethod
+    def fromstring(cls, string):
+        if not isinstance(string, str):
+            raise TypeError('string must be a string instance')
+        return Constant(Fraction(string))
+
+    def __repr__(self):
+        # integers print as Constant(n), other values as Constant(p, q)
+        name = self.__class__.__name__
+        num, den = self.constant.numerator, self.constant.denominator
+        if den != 1:
+            return '{}({!r}, {!r})'.format(name, num, den)
+        return '{}({!r})'.format(name, num)
+
+    @classmethod
+    def fromsympy(cls, expr):
+        import sympy
+        # sympy rationals expose numerator/denominator as .p/.q
+        if isinstance(expr, sympy.Rational):
+            return cls(expr.p, expr.q)
+        if isinstance(expr, numbers.Rational):
+            return cls(expr)
+        raise TypeError('expr must be a sympy.Rational instance')
diff --git a/pypol/polyhedra.py b/pypol/polyhedra.py
new file mode 100644 (file)
index 0000000..787e965
--- /dev/null
@@ -0,0 +1,304 @@
+import ast
+import functools
+import numbers
+import re
+
+from . import islhelper
+
+from .islhelper import mainctx, libisl
+from .linexprs import Expression, Constant
+from .domains import Domain
+
+
+# Names exported by ``from .polyhedra import *``.
+__all__ = [
+    'Polyhedron',
+    'Lt',
+    'Le',
+    'Eq',
+    'Ne',
+    'Ge',
+    'Gt',
+    'Empty',
+    'Universe',
+]
+
+
+class Polyhedron(Domain):
+    """A convex polyhedron given by linear equalities and inequalities.
+
+    Each equality expression ``e`` stands for ``e == 0`` and each
+    inequality expression for ``e >= 0``.  Instances are normalized by
+    converting the constraints to an ISL basic set and back.
+    """
+
+    __slots__ = (
+        '_equalities',
+        '_inequalities',
+        '_constraints',
+        '_symbols',
+        '_dimension',
+    )
+
+    def __new__(cls, equalities=None, inequalities=None):
+        """Create a polyhedron.
+
+        - Polyhedron(string) parses string with fromstring();
+        - Polyhedron(polyhedron) returns that polyhedron itself;
+        - Polyhedron(domain) returns the polyhedral hull of domain;
+        - Polyhedron(equalities, inequalities) builds the polyhedron from
+          two iterables of linear expressions (either may be omitted).
+
+        The iterables passed in are never modified.  Raises TypeError if
+        too many arguments are given or if a constraint is not a linear
+        expression.
+        """
+        if isinstance(equalities, str):
+            if inequalities is not None:
+                raise TypeError('too many arguments')
+            return cls.fromstring(equalities)
+        elif isinstance(equalities, Polyhedron):
+            if inequalities is not None:
+                raise TypeError('too many arguments')
+            return equalities
+        elif isinstance(equalities, Domain):
+            if inequalities is not None:
+                raise TypeError('too many arguments')
+            return equalities.polyhedral_hull()
+        # Work on fresh lists: rewriting the caller's containers in place
+        # silently altered their data and failed on tuples or generators.
+        equalities = [] if equalities is None else list(equalities)
+        for i, equality in enumerate(equalities):
+            if not isinstance(equality, Expression):
+                raise TypeError('equalities must be linear expressions')
+            equalities[i] = equality._toint()
+        inequalities = [] if inequalities is None else list(inequalities)
+        for i, inequality in enumerate(inequalities):
+            if not isinstance(inequality, Expression):
+                raise TypeError('inequalities must be linear expressions')
+            inequalities[i] = inequality._toint()
+        symbols = cls._xsymbols(equalities + inequalities)
+        islbset = cls._toislbasicset(equalities, inequalities, symbols)
+        return cls._fromislbasicset(islbset, symbols)
+
+    @property
+    def equalities(self):
+        """Tuple of the expressions e constrained by e == 0."""
+        return self._equalities
+
+    @property
+    def inequalities(self):
+        """Tuple of the expressions e constrained by e >= 0."""
+        return self._inequalities
+
+    @property
+    def constraints(self):
+        """Tuple of all equalities followed by all inequalities."""
+        return self._constraints
+
+    @property
+    def polyhedra(self):
+        """One-element tuple containing this polyhedron."""
+        return (self,)
+
+    def disjoint(self):
+        """Return self: a polyhedron is a single convex piece."""
+        return self
+
+    def isuniverse(self):
+        """Return True if ISL reports that this polyhedron is the universe."""
+        islbset = self._toislbasicset(self.equalities, self.inequalities,
+            self.symbols)
+        universe = bool(libisl.isl_basic_set_is_universe(islbset))
+        libisl.isl_basic_set_free(islbset)
+        return universe
+
+    def polyhedral_hull(self):
+        """Return self: a polyhedron is its own polyhedral hull."""
+        return self
+
+    @classmethod
+    def _fromislbasicset(cls, islbset, symbols):
+        """Build a Polyhedron from an ISL basic set, then free the set.
+
+        symbols lists, in order, the symbol of each set dimension.
+        """
+        # NOTE(review): ownership of the constraint objects returned by
+        # islhelper.isl_basic_set_constraints is not visible here; confirm
+        # whether they need to be freed.
+        islconstraints = islhelper.isl_basic_set_constraints(islbset)
+        equalities = []
+        inequalities = []
+        for islconstraint in islconstraints:
+            constant = libisl.isl_constraint_get_constant_val(islconstraint)
+            constant = islhelper.isl_val_to_int(constant)
+            coefficients = {}
+            for dim, symbol in enumerate(symbols):
+                coefficient = libisl.isl_constraint_get_coefficient_val(
+                    islconstraint, libisl.isl_dim_set, dim)
+                coefficient = islhelper.isl_val_to_int(coefficient)
+                if coefficient != 0:
+                    coefficients[symbol] = coefficient
+            expression = Expression(coefficients, constant)
+            if libisl.isl_constraint_is_equality(islconstraint):
+                equalities.append(expression)
+            else:
+                inequalities.append(expression)
+        libisl.isl_basic_set_free(islbset)
+        self = object.__new__(Polyhedron)
+        self._equalities = tuple(equalities)
+        self._inequalities = tuple(inequalities)
+        self._constraints = tuple(equalities + inequalities)
+        self._symbols = cls._xsymbols(self._constraints)
+        self._dimension = len(self._symbols)
+        return self
+
+    @classmethod
+    def _toislbasicset(cls, equalities, inequalities, symbols):
+        """Return a new ISL basic set holding the given constraints.
+
+        symbols fixes the order of the set dimensions; every symbol used
+        by a constraint must appear in it.  The caller owns the result.
+        """
+        dimension = len(symbols)
+        islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension)
+        islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp))
+        # isl_local_space_from_space takes ownership of islsp.
+        islls = libisl.isl_local_space_from_space(islsp)
+        for equality in equalities:
+            isleq = libisl.isl_equality_alloc(
+                libisl.isl_local_space_copy(islls))
+            isleq = cls._fillislconstraint(isleq, equality, symbols)
+            islbset = libisl.isl_basic_set_add_constraint(islbset, isleq)
+        for inequality in inequalities:
+            islin = libisl.isl_inequality_alloc(
+                libisl.isl_local_space_copy(islls))
+            islin = cls._fillislconstraint(islin, inequality, symbols)
+            islbset = libisl.isl_basic_set_add_constraint(islbset, islin)
+        # Every constraint received its own copy; release the original,
+        # which was previously leaked.
+        libisl.isl_local_space_free(islls)
+        return islbset
+
+    @staticmethod
+    def _fillislconstraint(islconstraint, expression, symbols):
+        # Copy the coefficients and the constant of expression into the
+        # ISL constraint and return the updated constraint.
+        for symbol, coefficient in expression.coefficients():
+            val = str(coefficient).encode()
+            val = libisl.isl_val_read_from_str(mainctx, val)
+            sid = symbols.index(symbol)
+            islconstraint = libisl.isl_constraint_set_coefficient_val(
+                islconstraint, libisl.isl_dim_set, sid, val)
+        if expression.constant != 0:
+            val = str(expression.constant).encode()
+            val = libisl.isl_val_read_from_str(mainctx, val)
+            islconstraint = libisl.isl_constraint_set_constant_val(
+                islconstraint, val)
+        return islconstraint
+
+    @classmethod
+    def _fromast(cls, node):
+        """Return (equalities, inequalities) lists for a parsed AST node.
+
+        Supports conjunctions with '&' and (chained) comparisons using
+        <, <=, ==, >= and >; raises SyntaxError on anything else.
+        """
+        if isinstance(node, ast.Module) and len(node.body) == 1:
+            return cls._fromast(node.body[0])
+        elif isinstance(node, ast.Expr):
+            return cls._fromast(node.value)
+        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd):
+            equalities1, inequalities1 = cls._fromast(node.left)
+            equalities2, inequalities2 = cls._fromast(node.right)
+            equalities = equalities1 + equalities2
+            inequalities = inequalities1 + inequalities2
+            return equalities, inequalities
+        elif isinstance(node, ast.Compare):
+            equalities = []
+            inequalities = []
+            left = Expression._fromast(node.left)
+            # A chain a < b <= c is read pairwise: (a, b), then (b, c).
+            for op, comparator in zip(node.ops, node.comparators):
+                right = Expression._fromast(comparator)
+                if isinstance(op, ast.Lt):
+                    inequalities.append(right - left - 1)
+                elif isinstance(op, ast.LtE):
+                    inequalities.append(right - left)
+                elif isinstance(op, ast.Eq):
+                    equalities.append(left - right)
+                elif isinstance(op, ast.GtE):
+                    inequalities.append(left - right)
+                elif isinstance(op, ast.Gt):
+                    inequalities.append(left - right - 1)
+                else:
+                    break
+                left = right
+            else:
+                return equalities, inequalities
+        raise SyntaxError('invalid syntax')
+
+    @classmethod
+    def fromstring(cls, string):
+        """Parse a polyhedron from a string such as '{x >= 0, y <= 2x}'.
+
+        Surrounding braces are optional.  Constraints are separated by
+        ',', ';', 'and', '&&', '/\\' or '∧'; a single '=' denotes
+        equality, and a number or ')' followed (optionally after spaces)
+        by a name or '(' denotes multiplication.
+        """
+        string = string.strip()
+        string = re.sub(r'^\{\s*|\s*\}$', '', string)
+        # Split into constraints *before* inserting implicit products:
+        # otherwise '0 and y' became '0*and y' and could not be parsed.
+        # 'and' must be a whole word so that names such as 'band' are
+        # not cut in two.
+        tokens = re.split(r',|;|\band\b|&&|/\\|∧', string, flags=re.I)
+        constraints = []
+        for token in tokens:
+            token = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', token)
+            token = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', token)
+            constraints.append('({})'.format(token))
+        string = ' & '.join(constraints)
+        # The default 'exec' mode yields an ast.Module, which _fromast
+        # unwraps (the old second argument 'eval' was only a filename).
+        tree = ast.parse(string)
+        equalities, inequalities = cls._fromast(tree)
+        return cls(equalities, inequalities)
+
+    def __repr__(self):
+        """Return 'Empty', 'Universe', or an Eq/Ge/And description."""
+        if self.isempty():
+            return 'Empty'
+        elif self.isuniverse():
+            return 'Universe'
+        else:
+            strings = []
+            for equality in self.equalities:
+                strings.append('Eq({}, 0)'.format(equality))
+            for inequality in self.inequalities:
+                strings.append('Ge({}, 0)'.format(inequality))
+            if len(strings) == 1:
+                return strings[0]
+            else:
+                return 'And({})'.format(', '.join(strings))
+
+    @classmethod
+    def _fromsympy(cls, expr):
+        """Return (equalities, inequalities) lists for a SymPy relation.
+
+        expr may be an And of relations or a single Eq, Lt, Le, Ge or Gt;
+        anything else raises ValueError.
+        """
+        import sympy
+        equalities = []
+        inequalities = []
+        if expr.func == sympy.And:
+            for arg in expr.args:
+                arg_eqs, arg_ins = cls._fromsympy(arg)
+                equalities.extend(arg_eqs)
+                inequalities.extend(arg_ins)
+        elif expr.func == sympy.Eq:
+            expr = Expression.fromsympy(expr.args[0] - expr.args[1])
+            equalities.append(expr)
+        else:
+            if expr.func == sympy.Lt:
+                expr = Expression.fromsympy(expr.args[1] - expr.args[0] - 1)
+            elif expr.func == sympy.Le:
+                expr = Expression.fromsympy(expr.args[1] - expr.args[0])
+            elif expr.func == sympy.Ge:
+                expr = Expression.fromsympy(expr.args[0] - expr.args[1])
+            elif expr.func == sympy.Gt:
+                expr = Expression.fromsympy(expr.args[0] - expr.args[1] - 1)
+            else:
+                raise ValueError('non-polyhedral expression: {!r}'.format(expr))
+            inequalities.append(expr)
+        return equalities, inequalities
+
+    @classmethod
+    def fromsympy(cls, expr):
+        """Build a polyhedron from a SymPy relation or conjunction."""
+        equalities, inequalities = cls._fromsympy(expr)
+        return cls(equalities, inequalities)
+
+    def tosympy(self):
+        """Return the constraints as a sympy.And of Eq(e, 0) and Ge(e, 0)."""
+        import sympy
+        constraints = []
+        for equality in self.equalities:
+            constraints.append(sympy.Eq(equality.tosympy(), 0))
+        for inequality in self.inequalities:
+            constraints.append(sympy.Ge(inequality.tosympy(), 0))
+        return sympy.And(*constraints)
+
+
+def _polymorphic(func):
+    """Decorate a binary constraint constructor so it also accepts numbers.
+
+    Rational-number operands are wrapped in Constant; an operand that is
+    neither a rational number nor an Expression raises TypeError.
+    """
+    @functools.wraps(func)
+    def wrapper(left, right):
+        if isinstance(left, numbers.Rational):
+            left = Constant(left)
+        elif not isinstance(left, Expression):
+            raise TypeError('left must be a rational number '
+                'or a linear expression')
+        if isinstance(right, numbers.Rational):
+            right = Constant(right)
+        elif not isinstance(right, Expression):
+            raise TypeError('right must be a rational number '
+                'or a linear expression')
+        return func(left, right)
+    return wrapper
+
+@_polymorphic
+def Lt(left, right):
+    """Return the polyhedron left < right, encoded as right - left - 1 >= 0."""
+    return Polyhedron(inequalities=[right - left - 1])
+
+@_polymorphic
+def Le(left, right):
+    """Return the polyhedron left <= right, encoded as right - left >= 0."""
+    return Polyhedron(inequalities=[right - left])
+
+@_polymorphic
+def Eq(left, right):
+    """Return the polyhedron left == right, encoded as left - right == 0."""
+    return Polyhedron(equalities=[left - right])
+
+@_polymorphic
+def Ne(left, right):
+    """Return the domain left != right, built by inverting (~) Eq."""
+    equal = Eq(left, right)
+    return ~equal
+
+@_polymorphic
+def Gt(left, right):
+    """Return the polyhedron left > right, encoded as left - right - 1 >= 0."""
+    return Polyhedron(inequalities=[left - right - 1])
+
+@_polymorphic
+def Ge(left, right):
+    """Return the polyhedron left >= right, encoded as left - right >= 0."""
+    return Polyhedron(inequalities=[left - right])
+
+
+# The polyhedron with no points: the equality 1 == 0 is unsatisfiable.
+Empty = Eq(1, 0)
+
+# The polyhedron without any constraint.
+Universe = Polyhedron()
similarity index 100%
rename from tests/__init__.py
rename to pypol/tests/__init__.py
diff --git a/pypol/tests/test_domains.py b/pypol/tests/test_domains.py
new file mode 100644 (file)
index 0000000..f9e7008
--- /dev/null
@@ -0,0 +1,12 @@
+import unittest
+
+from ..domains import *
+
+
+class TestDomain(unittest.TestCase):
+    """Placeholder test case for the domains module (no checks yet)."""
+
+    def setUp(self):
+        """No fixtures are needed yet."""
+
+    def test_new(self):
+        """TODO: exercise Domain construction."""
similarity index 80%
rename from tests/test_linear.py
rename to pypol/tests/test_linexprs.py
index 6cd1ff4..1606ea0 100644 (file)
@@ -3,7 +3,7 @@ import unittest
 
 from fractions import Fraction
 
-from pypol.linear import *
+from ..linexprs import *
 
 
 try:
@@ -32,15 +32,13 @@ class TestExpression(unittest.TestCase):
         self.pi = Expression(constant=Fraction(22, 7))
         self.expr = self.x - 2*self.y + 3
 
-    def test_new_subclass(self):
+    def test_new(self):
         self.assertIsInstance(self.x, Symbol)
         self.assertIsInstance(self.pi, Constant)
         self.assertNotIsInstance(self.x + self.pi, Symbol)
         self.assertNotIsInstance(self.x + self.pi, Constant)
         xx = Expression({'x': 2})
         self.assertNotIsInstance(xx, Symbol)
-
-    def test_new_types(self):
         with self.assertRaises(TypeError):
             Expression('x + y', 2)
         self.assertEqual(Expression({'x': 2}), Expression({self.x: 2}))
@@ -49,18 +47,9 @@ class TestExpression(unittest.TestCase):
         with self.assertRaises(TypeError):
             Expression({'x': '2'})
         self.assertEqual(Expression(constant=1), Expression(constant=self.one))
-        with self.assertRaises(TypeError):
-            Expression(constant='1')
-
-    def test_symbols(self):
-        self.assertCountEqual(self.x.symbols, ['x'])
-        self.assertCountEqual(self.pi.symbols, [])
-        self.assertCountEqual(self.expr.symbols, ['x', 'y'])
-
-    def test_dimension(self):
-        self.assertEqual(self.x.dimension, 1)
-        self.assertEqual(self.pi.dimension, 0)
-        self.assertEqual(self.expr.dimension, 2)
+        self.assertEqual(Expression(constant='1'), Expression(constant=self.one))
+        with self.assertRaises(ValueError):
+            Expression(constant='a')
 
     def test_coefficient(self):
         self.assertEqual(self.expr.coefficient('x'), 1)
@@ -90,19 +79,29 @@ class TestExpression(unittest.TestCase):
         self.assertEqual(self.pi.constant, Fraction(22, 7))
         self.assertEqual(self.expr.constant, 3)
 
+    def test_symbols(self):
+        self.assertCountEqual(self.x.symbols, ['x'])
+        self.assertCountEqual(self.pi.symbols, [])
+        self.assertCountEqual(self.expr.symbols, ['x', 'y'])
+
+    def test_dimension(self):
+        self.assertEqual(self.x.dimension, 1)
+        self.assertEqual(self.pi.dimension, 0)
+        self.assertEqual(self.expr.dimension, 2)
+
     def test_isconstant(self):
         self.assertFalse(self.x.isconstant())
         self.assertTrue(self.pi.isconstant())
         self.assertFalse(self.expr.isconstant())
 
-    def test_values(self):
-        self.assertCountEqual(self.expr.values(), [1, -2, 3])
-
     def test_issymbol(self):
         self.assertTrue(self.x.issymbol())
         self.assertFalse(self.pi.issymbol())
         self.assertFalse(self.expr.issymbol())
 
+    def test_values(self):
+        self.assertCountEqual(self.expr.values(), [1, -2, 3])
+
     def test_bool(self):
         self.assertTrue(self.x)
         self.assertFalse(self.zero)
@@ -132,11 +131,28 @@ class TestExpression(unittest.TestCase):
         self.assertEqual(0 * self.expr, 0)
         self.assertEqual(self.expr * 2, 2*self.x - 4*self.y + 6)
 
-    def test_div(self):
+    def test_truediv(self):
         with self.assertRaises(ZeroDivisionError):
             self.expr / 0
         self.assertEqual(self.expr / 2, self.x / 2 - self.y + Fraction(3, 2))
 
+    def test_eq(self):
+        self.assertEqual(self.expr, self.expr)
+        self.assertNotEqual(self.x, self.y)
+        self.assertEqual(self.zero, 0)
+
+    def test__toint(self):
+        self.assertEqual((self.x + self.y/2 + self.z/3)._toint(),
+                6*self.x + 3*self.y + 2*self.z)
+
+    def test_fromstring(self):
+        self.assertEqual(Expression.fromstring('x'), self.x)
+        self.assertEqual(Expression.fromstring('-x'), -self.x)
+        self.assertEqual(Expression.fromstring('22/7'), self.pi)
+        self.assertEqual(Expression.fromstring('x - 2y + 3'), self.expr)
+        self.assertEqual(Expression.fromstring('x - (3-1)y + 3'), self.expr)
+        self.assertEqual(Expression.fromstring('x - 2*y + 3'), self.expr)
+
     def test_str(self):
         self.assertEqual(str(Expression()), '0')
         self.assertEqual(str(self.x), 'x')
@@ -151,23 +167,6 @@ class TestExpression(unittest.TestCase):
         self.assertEqual(repr(self.x + self.one), "Expression('x + 1')")
         self.assertEqual(repr(self.expr), "Expression('x - 2*y + 3')")
 
-    def test_fromstring(self):
-        self.assertEqual(Expression.fromstring('x'), self.x)
-        self.assertEqual(Expression.fromstring('-x'), -self.x)
-        self.assertEqual(Expression.fromstring('22/7'), self.pi)
-        self.assertEqual(Expression.fromstring('x - 2y + 3'), self.expr)
-        self.assertEqual(Expression.fromstring('x - (3-1)y + 3'), self.expr)
-        self.assertEqual(Expression.fromstring('x - 2*y + 3'), self.expr)
-
-    def test_eq(self):
-        self.assertEqual(self.expr, self.expr)
-        self.assertNotEqual(self.x, self.y)
-        self.assertEqual(self.zero, 0)
-
-    def test__toint(self):
-        self.assertEqual((self.x + self.y/2 + self.z/3)._toint(),
-                6*self.x + 3*self.y + 2*self.z)
-
     @_requires_sympy
     def test_fromsympy(self):
         sp_x, sp_y = sympy.symbols('x y')
@@ -185,33 +184,34 @@ class TestExpression(unittest.TestCase):
         self.assertEqual(self.expr.tosympy(), sp_x - 2*sp_y + 3)
 
 
-class TestConstant(unittest.TestCase):
-
-    def setUp(self):
-        self.zero = Constant(0)
-        self.one = Constant(1)
-        self.pi = Constant(Fraction(22, 7))
-
-    @_requires_sympy
-    def test_fromsympy(self):
-        self.assertEqual(Constant.fromsympy(sympy.Rational(22, 7)), self.pi)
-        with self.assertRaises(TypeError):
-            Constant.fromsympy(sympy.Symbol('x'))
-
-
 class TestSymbol(unittest.TestCase):
 
     def setUp(self):
         self.x = Symbol('x')
         self.y = Symbol('y')
 
+    def test_new(self):
+        self.assertEqual(Symbol(' x '), self.x)
+        self.assertEqual(Symbol(self.x), self.x)
+        with self.assertRaises(TypeError):
+            Symbol(1)
+
     def test_name(self):
         self.assertEqual(self.x.name, 'x')
 
-    def test_symbols(self):
-        self.assertListEqual(list(symbols('x y')), [self.x, self.y])
-        self.assertListEqual(list(symbols('x,y')), [self.x, self.y])
-        self.assertListEqual(list(symbols(['x', 'y'])), [self.x, self.y])
+    def test_issymbol(self):
+        self.assertTrue(self.x.issymbol())
+
+    def test_fromstring(self):
+        self.assertEqual(Symbol.fromstring('x'), self.x)
+        with self.assertRaises(SyntaxError):
+            Symbol.fromstring('1')
+
+    def test_str(self):
+        self.assertEqual(str(self.x), 'x')
+
+    def test_repr(self):
+        self.assertEqual(repr(self.x), "Symbol('x')")
 
     @_requires_sympy
     def test_fromsympy(self):
@@ -224,75 +224,46 @@ class TestSymbol(unittest.TestCase):
         with self.assertRaises(TypeError):
             Symbol.fromsympy(sp_x*sp_x)
 
-
-class TestOperators(unittest.TestCase):
-
-    pass
+    def test_symbols(self):
+        self.assertListEqual(list(symbols('x y')), [self.x, self.y])
+        self.assertListEqual(list(symbols('x,y')), [self.x, self.y])
+        self.assertListEqual(list(symbols(['x', 'y'])), [self.x, self.y])
 
 
-class TestPolyhedron(unittest.TestCase):
+class TestConstant(unittest.TestCase):
 
     def setUp(self):
-        x, y = symbols('x y')
-        self.square = Polyhedron(inequalities=[x, 1 - x, y, 1 - y])
-
-    def test_symbols(self):
-        self.assertCountEqual(self.square.symbols, ['x', 'y'])
+        self.zero = Constant(0)
+        self.one = Constant(1)
+        self.pi = Constant(Fraction(22, 7))
 
-    def test_dimension(self):
-        self.assertEqual(self.square.dimension, 2)
+    def test_new(self):
+        self.assertEqual(Constant(), self.zero)
+        self.assertEqual(Constant(1), self.one)
+        self.assertEqual(Constant(self.pi), self.pi)
+        self.assertEqual(Constant('22/7'), self.pi)
 
-    def test_str(self):
-        self.assertEqual(str(self.square),
-            'x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0')
+    def test_isconstant(self):
+        self.assertTrue(self.zero.isconstant())
 
-    def test_repr(self):
-        self.assertEqual(repr(self.square),
-            "Polyhedron('x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0')")
+    def test_bool(self):
+        self.assertFalse(self.zero)
+        self.assertTrue(self.pi)
 
     def test_fromstring(self):
-        self.assertEqual(Polyhedron.fromstring('{x >= 0, -x + 1 >= 0, '
-            'y >= 0, -y + 1 >= 0}'), self.square)
-
-    def test_isempty(self):
-        self.assertFalse(self.square.isempty())
+        self.assertEqual(Constant.fromstring('22/7'), self.pi)
+        with self.assertRaises(ValueError):
+            Constant.fromstring('a')
+        with self.assertRaises(TypeError):
+            Constant.fromstring(1)
 
-    def test_isuniverse(self):
-        self.assertFalse(self.square.isuniverse())
+    def test_repr(self):
+        self.assertEqual(repr(self.zero), 'Constant(0)')
+        self.assertEqual(repr(self.one), 'Constant(1)')
+        self.assertEqual(repr(self.pi), 'Constant(22, 7)')
 
-    @unittest.expectedFailure
     @_requires_sympy
     def test_fromsympy(self):
-        sp_x, sp_y = sympy.symbols('x y')
-        self.assertEqual(Polyhedron.fromsympy((sp_x >= 0) & (sp_x <= 1) &
-            (sp_y >= 0) & (sp_y <= 1)), self.square)
-
-    @_requires_sympy
-    def test_tosympy(self):
-        sp_x, sp_y = sympy.symbols('x y')
-        self.assertEqual(self.square.tosympy(),
-            sympy.And(-sp_x + 1 >= 0, -sp_y + 1 >= 0, sp_x >= 0, sp_y >= 0))
-
-
-class TestEmpty:
-
-    def test_repr(self):
-        self.assertEqual(repr(Empty), 'Empty')
-
-    def test_isempty(self):
-        self.assertTrue(Empty.isempty())
-
-    def test_isuniverse(self):
-        self.assertFalse(Empty.isuniverse())
-
-
-class TestUniverse:
-
-    def test_repr(self):
-        self.assertEqual(repr(Universe), 'Universe')
-
-    def test_isempty(self):
-        self.assertTrue(Universe.isempty())
-
-    def test_isuniverse(self):
-        self.assertTrue(Universe.isuniverse())
+        self.assertEqual(Constant.fromsympy(sympy.Rational(22, 7)), self.pi)
+        with self.assertRaises(TypeError):
+            Constant.fromsympy(sympy.Symbol('x'))
diff --git a/pypol/tests/test_polyhedra.py b/pypol/tests/test_polyhedra.py
new file mode 100644 (file)
index 0000000..c74e25f
--- /dev/null
@@ -0,0 +1,87 @@
+import functools
+import unittest
+
+from ..linexprs import symbols
+from ..polyhedra import *
+
+
+try:
+    import sympy
+except ImportError:
+    def _requires_sympy(func):
+        """Replace a test method by one that is always skipped (no SymPy)."""
+        @functools.wraps(func)
+        def skipped(self):
+            raise unittest.SkipTest('SymPy is not available')
+        return skipped
+else:
+    def _requires_sympy(func):
+        """Wrap a test method unchanged, since SymPy is available."""
+        @functools.wraps(func)
+        def passthrough(self):
+            return func(self)
+        return passthrough
+
+
+class TestPolyhedron(unittest.TestCase):
+    """Tests on the unit square 0 <= x <= 1, 0 <= y <= 1."""
+
+    def setUp(self):
+        x, y = symbols('x y')
+        sides = [x, 1 - x, y, 1 - y]
+        self.square = Polyhedron(inequalities=sides)
+
+    def test_symbols(self):
+        self.assertCountEqual(self.square.symbols, ['x', 'y'])
+
+    def test_dimension(self):
+        self.assertEqual(self.square.dimension, 2)
+
+    def test_str(self):
+        expected = 'And(Ge(x, 0), Ge(-x + 1, 0), Ge(y, 0), Ge(-y + 1, 0))'
+        self.assertEqual(str(self.square), expected)
+
+    def test_repr(self):
+        expected = 'And(Ge(x, 0), Ge(-x + 1, 0), Ge(y, 0), Ge(-y + 1, 0))'
+        self.assertEqual(repr(self.square), expected)
+
+    def test_fromstring(self):
+        text = '{x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0}'
+        self.assertEqual(Polyhedron.fromstring(text), self.square)
+
+    def test_isempty(self):
+        self.assertFalse(self.square.isempty())
+
+    def test_isuniverse(self):
+        self.assertFalse(self.square.isuniverse())
+
+    @_requires_sympy
+    def test_fromsympy(self):
+        sp_x, sp_y = sympy.symbols('x y')
+        relation = (sp_x >= 0) & (sp_x <= 1) & (sp_y >= 0) & (sp_y <= 1)
+        self.assertEqual(Polyhedron.fromsympy(relation), self.square)
+
+    @_requires_sympy
+    def test_tosympy(self):
+        sp_x, sp_y = sympy.symbols('x y')
+        expected = sympy.And(-sp_x + 1 >= 0, -sp_y + 1 >= 0,
+            sp_x >= 0, sp_y >= 0)
+        self.assertEqual(self.square.tosympy(), expected)
+
+
+class TestEmpty(unittest.TestCase):
+    """Tests for the Empty polyhedron.
+
+    Must subclass unittest.TestCase: as a plain class its tests were
+    never collected and self.assert* methods did not exist.
+    """
+
+    def test_repr(self):
+        self.assertEqual(repr(Empty), 'Empty')
+
+    def test_isempty(self):
+        self.assertTrue(Empty.isempty())
+
+    def test_isuniverse(self):
+        self.assertFalse(Empty.isuniverse())
+
+
+class TestUniverse(unittest.TestCase):
+    """Tests for the Universe polyhedron.
+
+    Must subclass unittest.TestCase: as a plain class its tests were
+    never collected and self.assert* methods did not exist.
+    """
+
+    def test_repr(self):
+        self.assertEqual(repr(Universe), 'Universe')
+
+    def test_isempty(self):
+        # The universe contains every point, so it is not empty.
+        self.assertFalse(Universe.isempty())
+
+    def test_isuniverse(self):
+        self.assertTrue(Universe.isuniverse())
index b530f28..368983e 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -8,8 +8,8 @@ setup(
     author='MINES ParisTech',
     packages=['pypol'],
     ext_modules = [
-        Extension('pypol._isl',
-            sources=['pypol/_isl.c'],
+        Extension('pypol._islhelper',
+            sources=['pypol/_islhelper.c'],
             libraries=['isl'])
     ]
 )
diff --git a/tests/test_isl.py b/tests/test_isl.py
deleted file mode 100644 (file)
index 21374b5..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-import unittest
-
-from math import floor, ceil, trunc
-
-from pypol.isl import *
-
-
-class TestContext(unittest.TestCase):
-
-    def test_eq(self):
-        ctx1, ctx2 = Context(), Context()
-        self.assertEqual(ctx1, ctx1)
-        self.assertNotEqual(ctx1, ctx2)