From 1d494bb187b70135df721c13306d7f26fdf33f50 Mon Sep 17 00:00:00 2001 From: Vivien Maisonneuve Date: Wed, 25 Jun 2014 13:56:34 +0200 Subject: [PATCH] Split linear.py and add domains --- Makefile | 5 +- pypol/.gitignore | 2 +- pypol/__init__.py | 12 +- pypol/_isl.c | 89 -- pypol/_islhelper.c | 148 ++++ pypol/domains.py | 271 ++++++ pypol/isl.py | 136 --- pypol/islhelper.py | 41 + pypol/linear.py | 784 ------------------ pypol/linexprs.py | 431 ++++++++++ pypol/polyhedra.py | 304 +++++++ {tests => pypol/tests}/__init__.py | 0 pypol/tests/test_domains.py | 12 + .../tests/test_linexprs.py | 199 ++--- pypol/tests/test_polyhedra.py | 87 ++ setup.py | 4 +- tests/test_isl.py | 13 - 17 files changed, 1390 insertions(+), 1148 deletions(-) delete mode 100644 pypol/_isl.c create mode 100644 pypol/_islhelper.c create mode 100644 pypol/domains.py delete mode 100644 pypol/isl.py create mode 100644 pypol/islhelper.py delete mode 100644 pypol/linear.py create mode 100644 pypol/linexprs.py create mode 100644 pypol/polyhedra.py rename {tests => pypol/tests}/__init__.py (100%) create mode 100644 pypol/tests/test_domains.py rename tests/test_linear.py => pypol/tests/test_linexprs.py (80%) create mode 100644 pypol/tests/test_polyhedra.py delete mode 100644 tests/test_isl.py diff --git a/Makefile b/Makefile index f2f5e9c..687d36b 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,5 @@ test: build .PHONY: clean clean: - $(RM) build dist MANIFEST venv - $(RM) $(NAME).egg-info $(NAME)/_isl.*.so $(NAME)/__pycache__ - $(RM) tests/__pycache__ + $(RM) build dist MANIFEST venv $(NAME).egg-info + $(RM) $(NAME)/_islhelper.*.so $(NAME)/__pycache__ $(NAME)/tests/__pycache__ diff --git a/pypol/.gitignore b/pypol/.gitignore index 7951d19..fd69600 100644 --- a/pypol/.gitignore +++ b/pypol/.gitignore @@ -1 +1 @@ -/_isl.*.so +/_islhelper.*.so diff --git a/pypol/__init__.py b/pypol/__init__.py index 451d7f1..fc70bea 100644 --- a/pypol/__init__.py +++ b/pypol/__init__.py @@ -2,13 +2,13 @@ A polyhedral library based on ISL. """ -from .linear import Polyhedron, Constant, Symbol, symbols -from .linear import eq, le, lt, ge, gt -from .linear import Empty, Universe +from .linexprs import Expression, Constant, Symbol, symbols +from .polyhedra import Polyhedron, Eq, Ne, Le, Lt, Ge, Gt, Ne, Empty, Universe +from .domains import Domain, And, Or, Not __all__ = [ - 'Polyhedron', 'Constant', 'Symbol', 'symbols', - 'eq', 'le', 'lt', 'ge', 'gt', - 'Empty', 'Universe' + 'Expression', 'Constant', 'Symbol', 'symbols', + 'Polyhedron', 'Eq', 'Ne', 'Le', 'Lt', 'Ge', 'Gt', 'Empty', 'Universe', + 'Domain', 'And', 'Or', 'Not', ] diff --git a/pypol/_isl.c b/pypol/_isl.c deleted file mode 100644 index 1b086dc..0000000 --- a/pypol/_isl.c +++ /dev/null @@ -1,89 +0,0 @@ -#include - -#include -#include - -struct pointer_list { - int cursor; - PyObject *pointers; -}; -typedef struct pointer_list pointer_list; - -static int pointer_list_append_constraint(isl_constraint *c, void *user) { - pointer_list *list; - PyObject *value; - - list = (pointer_list *) user; - value = PyLong_FromVoidPtr(c); - if (value == NULL) { - return -1; - } - PyList_SET_ITEM(list->pointers, list->cursor++, value); - return 0; -} - -static PyObject * basic_set_constraints(PyObject *self, PyObject* args) { - long ptr; - isl_basic_set *bset; - int n; - PyObject *pointers; - pointer_list *list; - - if (!PyArg_ParseTuple(args, "l", &ptr)) - return NULL; - bset = (isl_basic_set*) ptr; - n = isl_basic_set_n_constraint(bset); - if (n == -1) { - PyErr_SetString(PyExc_RuntimeError, - "an error occurred in isl_basic_set_n_constraint"); - return NULL; - } - pointers = PyList_New(n); - if (pointers == NULL) { - return NULL; - } - list = malloc(sizeof(pointer_list)); - if (list == NULL) { - Py_DECREF(pointers); - return PyErr_NoMemory(); - } - list->cursor = 0; - list->pointers = pointers; - n = isl_basic_set_foreach_constraint(bset, pointer_list_append_constraint, - list); - free(list); - if (n == -1) { - PyErr_SetString(PyExc_RuntimeError, - "an error occurred in isl_basic_set_foreach_constraint"); - Py_DECREF(pointers); - return NULL; - } - return pointers; -} - -static PyMethodDef _isl_methods[] = { - {"basic_set_constraints", basic_set_constraints, METH_VARARGS, NULL}, - {NULL, NULL, 0, NULL} -}; - -static struct PyModuleDef _islmodule = { - PyModuleDef_HEAD_INIT, - "_isl", - NULL, - 0, - _isl_methods -}; - -PyMODINIT_FUNC PyInit__isl(void) { - PyObject *m; - m = PyModule_Create(&_islmodule); - if (m == NULL) { - return NULL; - } - - if (PyModule_AddObject(m, "dim_set", PyLong_FromLong(isl_dim_set)) == -1) { - return NULL; - } - - return m; -} diff --git a/pypol/_islhelper.c b/pypol/_islhelper.c new file mode 100644 index 0000000..f8c03e3 --- /dev/null +++ b/pypol/_islhelper.c @@ -0,0 +1,148 @@ +#include + +#include +#include + + +struct pointer_list { + int cursor; + PyObject *pointers; +}; + +typedef struct pointer_list pointer_list; + + +static int pointer_list_append_constraint(isl_constraint *c, void *user) { + pointer_list *list; + PyObject *value; + + list = (pointer_list *) user; + value = PyLong_FromVoidPtr(c); + if (value == NULL) { + return -1; + } + PyList_SET_ITEM(list->pointers, list->cursor++, value); + return 0; +} + +static PyObject * isl_basic_set_constraints(PyObject *self, PyObject* args) { + long ptr; + isl_basic_set *bset; + int n; + PyObject *pointers; + pointer_list *list; + + if (!PyArg_ParseTuple(args, "l", &ptr)) { + return NULL; + } + bset = (isl_basic_set *) ptr; + n = isl_basic_set_n_constraint(bset); + if (n == -1) { + PyErr_SetString(PyExc_RuntimeError, + "an error occurred in isl_basic_set_n_constraint"); + return NULL; + } + pointers = PyList_New(n); + if (pointers == NULL) { + return NULL; + } + list = malloc(sizeof(pointer_list)); + if (list == NULL) { + Py_DECREF(pointers); + return PyErr_NoMemory(); + } + list->cursor = 0; + list->pointers = pointers; + n = isl_basic_set_foreach_constraint(bset, pointer_list_append_constraint, + list); + free(list); + if (n == -1) { + PyErr_SetString(PyExc_RuntimeError, + "an error occurred in isl_basic_set_foreach_constraint"); + Py_DECREF(pointers); + return NULL; + } + return pointers; +} + + +static int pointer_list_append_basic_set(isl_basic_set *bset, void *user) { + pointer_list *list; + PyObject *value; + + list = (pointer_list *) user; + value = PyLong_FromVoidPtr(bset); + if (value == NULL) { + return -1; + } + PyList_SET_ITEM(list->pointers, list->cursor++, value); + return 0; +} + +static PyObject * isl_set_basic_sets(PyObject *self, PyObject *args) { + long ptr; + isl_set *set; + int n; + PyObject *pointers; + pointer_list *list; + + if (!PyArg_ParseTuple(args, "l", &ptr)) { + return NULL; + } + set = (isl_set *) ptr; + n = isl_set_n_basic_set(set); + if (n == -1) { + PyErr_SetString(PyExc_RuntimeError, + "an error occurred in isl_set_n_basic_set"); + return NULL; + } + pointers = PyList_New(n); + if (pointers == NULL) { + return NULL; + } + list = malloc(sizeof(pointer_list)); + if (list == NULL) { + Py_DECREF(pointers); + return PyErr_NoMemory(); + } + list->cursor = 0; + list->pointers = pointers; + n = isl_set_foreach_basic_set(set, pointer_list_append_basic_set, list); + free(list); + if (n == -1) { + PyErr_SetString(PyExc_RuntimeError, + "an error occurred in isl_set_foreach_basic_set"); + Py_DECREF(pointers); + return NULL; + } + return pointers; +} + + +static PyMethodDef _islhelper_methods[] = { + {"isl_basic_set_constraints", isl_basic_set_constraints, METH_VARARGS, NULL}, + {"isl_set_basic_sets", isl_set_basic_sets, METH_VARARGS, NULL}, + {NULL, NULL, 0, NULL} +}; + +static struct PyModuleDef _islhelpermodule = { + PyModuleDef_HEAD_INIT, + "_islhelper", + NULL, + 0, + _islhelper_methods +}; + +PyMODINIT_FUNC PyInit__islhelper(void) { + PyObject *m; + m = PyModule_Create(&_islhelpermodule); + if (m == NULL) { + return NULL; + } + + if (PyModule_AddObject(m, "dim_set", PyLong_FromLong(isl_dim_set)) == -1) { + return NULL; + } + + return m; +} diff --git a/pypol/domains.py b/pypol/domains.py new file mode 100644 index 0000000..fd588b7 --- /dev/null +++ b/pypol/domains.py @@ -0,0 +1,271 @@ +import functools + +from . import islhelper + +from .islhelper import mainctx, libisl, isl_set_basic_sets + + +__all__ = [ + 'Domain', + 'And', 'Or', 'Not', +] + + +@functools.total_ordering +class Domain: + + __slots__ = ( + '_polyhedra', + '_symbols', + '_dimension', + ) + + def __new__(cls, *polyhedra): + from .polyhedra import Polyhedron + if len(polyhedra) == 1: + polyhedron = polyhedra[0] + if isinstance(polyhedron, str): + return cls.fromstring(polyhedron) + elif isinstance(polyhedron, Polyhedron): + return polyhedron + else: + raise TypeError('argument must be a string ' + 'or a Polyhedron instance') + else: + for polyhedron in polyhedra: + if not isinstance(polyhedron, Polyhedron): + raise TypeError('arguments must be Polyhedron instances') + symbols = cls._xsymbols(polyhedra) + islset = cls._toislset(polyhedra, symbols) + return cls._fromislset(islset, symbols) + + @classmethod + def _xsymbols(cls, iterator): + """ + Return the ordered tuple of symbols present in iterator. + """ + symbols = set() + for item in iterator: + symbols.update(item.symbols) + return tuple(sorted(symbols)) + + @property + def polyhedra(self): + return self._polyhedra + + @property + def symbols(self): + return self._symbols + + @property + def dimension(self): + return self._dimension + + def disjoint(self): + islset = self._toislset(self.polyhedra, self.symbols) + islset = libisl.isl_set_make_disjoint(mainctx, islset) + return self._fromislset(islset, self.symbols) + + def isempty(self): + islset = self._toislset(self.polyhedra, self.symbols) + empty = bool(libisl.isl_set_is_empty(islset)) + libisl.isl_set_free(islset) + return empty + + def __bool__(self): + return not self.isempty() + + def isuniverse(self): + islset = self._toislset(self.polyhedra, self.symbols) + universe = bool(libisl.isl_set_plain_is_universe(islset)) + libisl.isl_set_free(islset) + return universe + + def __eq__(self, other): + symbols = self._xsymbols([self, other]) + islset1 = self._toislset(self.polyhedra, symbols) + islset2 = other._toislset(other.polyhedra, symbols) + equal = bool(libisl.isl_set_is_equal(islset1, islset2)) + libisl.isl_set_free(islset1) + libisl.isl_set_free(islset2) + return equal + + def isdisjoint(self, other): + symbols = self._xsymbols([self, other]) + islset1 = self._toislset(self.polyhedra, symbols) + islset2 = self._toislset(other.polyhedra, symbols) + equal = bool(libisl.isl_set_is_disjoint(islset1, islset2)) + libisl.isl_set_free(islset1) + libisl.isl_set_free(islset2) + return equal + + def issubset(self, other): + symbols = self._xsymbols([self, other]) + islset1 = self._toislset(self.polyhedra, symbols) + islset2 = self._toislset(other.polyhedra, symbols) + equal = bool(libisl.isl_set_is_subset(islset1, islset2)) + libisl.isl_set_free(islset1) + libisl.isl_set_free(islset2) + return equal + + def __le__(self, other): + return self.issubset(other) + + def __lt__(self, other): + symbols = self._xsymbols([self, other]) + islset1 = self._toislset(self.polyhedra, symbols) + islset2 = self._toislset(other.polyhedra, symbols) + equal = bool(libisl.isl_set_is_strict_subset(islset1, islset2)) + libisl.isl_set_free(islset1) + libisl.isl_set_free(islset2) + return equal + + def complement(self): + islset = self._toislset(self.polyhedra, self.symbols) + islset = libisl.isl_set_complement(islset) + return self._fromislset(islset, self.symbols) + + def __invert__(self): + return self.complement() + + def simplify(self): + # see isl_set_coalesce, isl_set_detect_equalities, + # isl_set_remove_redundancies + # which ones? in which order? + raise NotImplementedError + + def polyhedral_hull(self): + # several types of hull are available + # polyhedral seems to be the more appropriate, to be checked + from .polyhedra import Polyhedron + islset = self._toislset(self.polyhedra, self.symbols) + islbset = libisl.isl_set_polyhedral_hull(islset) + return Polyhedron._fromislbasicset(islbset, self.symbols) + + def project(self, symbols): + # not sure what isl_set_project_out actually does… + # use isl_set_drop_constraints_involving_dims instead? + raise NotImplementedError + + def sample(self): + from .polyhedra import Polyhedron + islset = self._toislset(self.polyhedra, self.symbols) + islbset = libisl.isl_set_sample(islset) + return Polyhedron._fromislbasicset(islbset, self.symbols) + + def intersection(self, *others): + if len(others) == 0: + return self + symbols = self._xsymbols((self,) + others) + islset1 = self._toislset(self.polyhedra, symbols) + for other in others: + islset2 = other._toislset(other.polyhedra, symbols) + islset1 = libisl.isl_set_intersect(islset1, islset2) + return self._fromislset(islset1, symbols) + + def __and__(self, other): + return self.intersection(other) + + def union(self, *others): + if len(others) == 0: + return self + symbols = self._xsymbols((self,) + others) + islset1 = self._toislset(self.polyhedra, symbols) + for other in others: + islset2 = other._toislset(other.polyhedra, symbols) + islset1 = libisl.isl_set_union(islset1, islset2) + return self._fromislset(islset1, symbols) + + def __or__(self, other): + return self.union(other) + + def __add__(self, other): + return self.union(other) + + def difference(self, other): + symbols = self._xsymbols([self, other]) + islset1 = self._toislset(self.polyhedra, symbols) + islset2 = other._toislset(other.polyhedra, symbols) + islset = libisl.isl_set_subtract(islset1, islset2) + return self._fromislset(islset, symbols) + + def __sub__(self, other): + return self.difference(other) + + def lexmin(self): + islset = self._toislset(self.polyhedra, self.symbols) + islset = libisl.isl_set_lexmin(islset) + return self._fromislset(islset, self.symbols) + + def lexmax(self): + islset = self._toislset(self.polyhedra, self.symbols) + islset = libisl.isl_set_lexmax(islset) + return self._fromislset(islset, self.symbols) + + @classmethod + def _fromislset(cls, islset, symbols): + from .polyhedra import Polyhedron + islset = libisl.isl_set_remove_divs(islset) + islbsets = isl_set_basic_sets(islset) + libisl.isl_set_free(islset) + polyhedra = [] + for islbset in islbsets: + polyhedron = Polyhedron._fromislbasicset(islbset, symbols) + polyhedra.append(polyhedron) + if len(polyhedra) == 0: + from .polyhedra import Empty + return Empty + elif len(polyhedra) == 1: + return polyhedra[0] + else: + self = object().__new__(Domain) + self._polyhedra = tuple(polyhedra) + self._symbols = cls._xsymbols(polyhedra) + self._dimension = len(self._symbols) + return self + + def _toislset(cls, polyhedra, symbols): + polyhedron = polyhedra[0] + islbset = polyhedron._toislbasicset(polyhedron.equalities, + polyhedron.inequalities, symbols) + islset1 = libisl.isl_set_from_basic_set(islbset) + for polyhedron in polyhedra[1:]: + islbset = polyhedron._toislbasicset(polyhedron.equalities, + polyhedron.inequalities, symbols) + islset2 = libisl.isl_set_from_basic_set(islbset) + islset1 = libisl.isl_set_union(islset1, islset2) + return islset1 + + @classmethod + def fromstring(cls, string): + raise NotImplementedError + + def __repr__(self): + assert len(self.polyhedra) >= 2 + strings = [repr(polyhedron) for polyhedron in self.polyhedra] + return 'Or({})'.format(', '.join(strings)) + + @classmethod + def fromsympy(cls, expr): + raise NotImplementedError + + def tosympy(self): + raise NotImplementedError + + +def And(*domains): + if len(domains) == 0: + from .polyhedra import Universe + return Universe + else: + return domains[0].intersection(*domains[1:]) + +def Or(*domains): + if len(domains) == 0: + from .polyhedra import Empty + return Empty + else: + return domains[0].union(*domains[1:]) + +def Not(domain): + return ~domain diff --git a/pypol/isl.py b/pypol/isl.py deleted file mode 100644 index 32ce305..0000000 --- a/pypol/isl.py +++ /dev/null @@ -1,136 +0,0 @@ -import ctypes, ctypes.util - -from . import _isl - - -__all__ = [ - 'Context', - 'BasicSet', -] - - -libisl = ctypes.CDLL(ctypes.util.find_library('isl')) - -libisl.isl_printer_get_str.restype = ctypes.c_char_p -libisl.isl_dim_set = _isl.dim_set - - -class IslObject: - - __slots__ = ( - '_ptr', - ) - - def __init__(self, ptr): - self._ptr = ptr - - @property - def _as_parameter_(self): - return self._ptr - - -class Context(IslObject): - - def __init__(self): - ptr = libisl.isl_ctx_alloc() - super().__init__(ptr) - - #comment out so does not delete itself after being created - #def __del__(self): - # libisl.isl_ctx_free(self) - - def __eq__(self, other): - if not isinstance(other, Context): - return False - return self._ptr == other._ptr - - -class BasicSet(IslObject): - - def __str__(self): - ls = libisl.isl_basic_set_get_local_space(self) - ctx = libisl.isl_local_space_get_ctx(ls) - p = libisl.isl_printer_to_str(ctx) - p = libisl.isl_printer_print_basic_set(p, self) - string = libisl.isl_printer_get_str(p).decode() - return string - - def __del__(self): - libisl.isl_basic_set_free(self) - - def constraints(self): - return _isl.basic_set_constraints(self._ptr) - - def _fromisl(self, cls, symbols): - constraints = self.constraints() - equalities = [] - inequalities = [] - co = [] - eq_string = "" - in_string = "" - string = "" - for constraint in constraints: - ls = libisl.isl_basic_set_get_local_space(self) - ctx = libisl.isl_local_space_get_ctx(ls) - p = libisl.isl_printer_to_str(ctx) - if libisl.isl_constraint_is_equality(constraint): #check if equality - constant = libisl.isl_constraint_get_constant_val(constraint) - const = libisl.isl_printer_print_val(p, constant) - const = libisl.isl_printer_get_str(const).decode() - const = int(const) - libisl.isl_printer_free(p) - for symbol in symbols: - p = libisl.isl_printer_to_str(ctx) - dim = symbols.index(symbol) - coefficient = libisl.isl_constraint_get_coefficient_val(constraint, libisl.isl_dim_set, dim) - coeff = libisl.isl_printer_print_val(p, coefficient) - coeff = libisl.isl_printer_get_str(coeff).decode() - coeff = int(coeff) - if coeff!=0: - co.append('{}{}'.format(coeff, symbols[dim])) - for value in co: - string += '{}+'.format(value) - equalities.append('{}{}==0'.format(string, const)) - co = [] - string = '' - libisl.isl_printer_free(p) - else: #same for inequality - constant = libisl.isl_constraint_get_constant_val(constraint) - const = libisl.isl_printer_print_val(p, constant) - const = libisl.isl_printer_get_str(const).decode() - const = int(const) - libisl.isl_printer_free(p) - for symbol in symbols: - p = libisl.isl_printer_to_str(ctx) - dim = symbols.index(symbol) - coefficient = libisl.isl_constraint_get_coefficient_val(constraint, libisl.isl_dim_set, dim) - coeff = libisl.isl_printer_print_val(p, coefficient) - coeff = libisl.isl_printer_get_str(coeff).decode() - coeff = int(coeff) - if coeff!=0: - co.append('{}{}'.format(coeff, symbols[dim])) - for value in co: - string += '{} + '.format(value) - inequalities.append('{}{} <= 0'.format(string, const)) - co = [] - string = "" - libisl.isl_printer_free(p) - - for equations in equalities: - eq_string += ' {}'.format(equations) - eq_strings = eq_string.split() - print(eq_strings) - - for equations in inequalities: - in_string += ', {}'.format(equations) - print(in_string) - if eq_string and in_string: - final = '{}, {}'.format(eq_string, in_string) - elif eq_string != '': - final = '{}'.format(eq_strings) - elif in_string != '' : - final = '{}'.format(in_string) - - - return ('{}({!r})'.format(cls.__name__,final)) - diff --git a/pypol/islhelper.py b/pypol/islhelper.py new file mode 100644 index 0000000..75d90d0 --- /dev/null +++ b/pypol/islhelper.py @@ -0,0 +1,41 @@ +import ctypes, ctypes.util + +from . import _islhelper +from ._islhelper import isl_basic_set_constraints, isl_set_basic_sets + + +__all__ = [ + 'libisl', + 'mainctx', + 'isl_val_to_int', + 'isl_basic_set_to_str', 'isl_basic_set_constraints', + 'isl_set_to_str', 'isl_set_basic_sets', +] + + +libisl = ctypes.CDLL(ctypes.util.find_library('isl')) + +libisl.isl_printer_get_str.restype = ctypes.c_char_p +libisl.isl_dim_set = _islhelper.dim_set + + +mainctx = libisl.isl_ctx_alloc() + + +def isl_val_to_int(islval): + islpr = libisl.isl_printer_to_str(mainctx) + islpr = libisl.isl_printer_print_val(islpr, islval) + string = libisl.isl_printer_get_str(islpr).decode() + return int(string) + +def isl_basic_set_to_str(islbset): + islpr = libisl.isl_printer_to_str(mainctx) + islpr = libisl.isl_printer_print_basic_set(islpr, islbset) + string = libisl.isl_printer_get_str(islpr).decode() + return string + +def isl_set_to_str(islset): + islpr = libisl.isl_printer_to_str(mainctx) + islpr = libisl.isl_printer_print_set(islpr, islset) + string = libisl.isl_printer_get_str(islpr).decode() + return string diff --git a/pypol/linear.py b/pypol/linear.py deleted file mode 100644 index b40415f..0000000 --- a/pypol/linear.py +++ /dev/null @@ -1,784 +0,0 @@ -import ast -import functools -import numbers -import re - -from fractions import Fraction, gcd - -from . import isl -from .isl import libisl - - -__all__ = [ - 'Expression', 'Constant', 'Symbol', 'symbols', - 'eq', 'le', 'lt', 'ge', 'gt', - 'Polyhedron', - 'Empty', 'Universe' -] - - -def _polymorphic_method(func): - @functools.wraps(func) - def wrapper(a, b): - if isinstance(b, Expression): - return func(a, b) - if isinstance(b, numbers.Rational): - b = Constant(b) - return func(a, b) - return NotImplemented - return wrapper - -def _polymorphic_operator(func): - # A polymorphic operator should call a polymorphic method, hence we just - # have to test the left operand. - @functools.wraps(func) - def wrapper(a, b): - if isinstance(a, numbers.Rational): - a = Constant(a) - return func(a, b) - elif isinstance(a, Expression): - return func(a, b) - raise TypeError('arguments must be linear expressions') - return wrapper - - -_main_ctx = isl.Context() - - -class Expression: - """ - This class implements linear expressions. - """ - - __slots__ = ( - '_coefficients', - '_constant', - '_symbols', - '_dimension', - ) - - def __new__(cls, coefficients=None, constant=0): - if isinstance(coefficients, str): - if constant: - raise TypeError('too many arguments') - return cls.fromstring(coefficients) - if isinstance(coefficients, dict): - coefficients = coefficients.items() - if coefficients is None: - return Constant(constant) - coefficients = [(symbol, coefficient) - for symbol, coefficient in coefficients if coefficient != 0] - if len(coefficients) == 0: - return Constant(constant) - elif len(coefficients) == 1 and constant == 0: - symbol, coefficient = coefficients[0] - if coefficient == 1: - return Symbol(symbol) - self = object().__new__(cls) - self._coefficients = {} - for symbol, coefficient in coefficients: - if isinstance(symbol, Symbol): - symbol = symbol.name - elif not isinstance(symbol, str): - raise TypeError('symbols must be strings or Symbol instances') - if isinstance(coefficient, Constant): - coefficient = coefficient.constant - if not isinstance(coefficient, numbers.Rational): - raise TypeError('coefficients must be rational numbers or Constant instances') - self._coefficients[symbol] = coefficient - if isinstance(constant, Constant): - constant = constant.constant - if not isinstance(constant, numbers.Rational): - raise TypeError('constant must be a rational number or a Constant instance') - self._constant = constant - self._symbols = tuple(sorted(self._coefficients)) - self._dimension = len(self._symbols) - return self - - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.Name): - return Symbol(node.id) - elif isinstance(node, ast.Num): - return Constant(node.n) - elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - return -cls._fromast(node.operand) - elif isinstance(node, ast.BinOp): - left = cls._fromast(node.left) - right = cls._fromast(node.right) - if isinstance(node.op, ast.Add): - return left + right - elif isinstance(node.op, ast.Sub): - return left - right - elif isinstance(node.op, ast.Mult): - return left * right - elif isinstance(node.op, ast.Div): - return left / right - raise SyntaxError('invalid syntax') - - @classmethod - def fromstring(cls, string): - string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string) - tree = ast.parse(string, 'eval') - return cls._fromast(tree) - - @property - def symbols(self): - return self._symbols - - @property - def dimension(self): - return self._dimension - - def coefficient(self, symbol): - if isinstance(symbol, Symbol): - symbol = str(symbol) - elif not isinstance(symbol, str): - raise TypeError('symbol must be a string or a Symbol instance') - try: - return self._coefficients[symbol] - except KeyError: - return 0 - - __getitem__ = coefficient - - def coefficients(self): - for symbol in self.symbols: - yield symbol, self.coefficient(symbol) - - @property - def constant(self): - return self._constant - - def isconstant(self): - return False - - def values(self): - for symbol in self.symbols: - yield self.coefficient(symbol) - yield self.constant - - def issymbol(self): - return False - - def __bool__(self): - return True - - def __pos__(self): - return self - - def __neg__(self): - return self * -1 - - @_polymorphic_method - def __add__(self, other): - coefficients = dict(self.coefficients()) - for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] += coefficient - else: - coefficients[symbol] = coefficient - constant = self.constant + other.constant - return Expression(coefficients, constant) - - __radd__ = __add__ - - @_polymorphic_method - def __sub__(self, other): - coefficients = dict(self.coefficients()) - for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] -= coefficient - else: - coefficients[symbol] = -coefficient - constant = self.constant - other.constant - return Expression(coefficients, constant) - - def __rsub__(self, other): - return -(self - other) - - @_polymorphic_method - def __mul__(self, other): - if other.isconstant(): - coefficients = dict(self.coefficients()) - for symbol in coefficients: - coefficients[symbol] *= other.constant - constant = self.constant * other.constant - return Expression(coefficients, constant) - if isinstance(other, Expression) and not self.isconstant(): - raise ValueError('non-linear expression: ' - '{} * {}'.format(self._parenstr(), other._parenstr())) - return NotImplemented - - __rmul__ = __mul__ - - @_polymorphic_method - def __truediv__(self, other): - if other.isconstant(): - coefficients = dict(self.coefficients()) - for symbol in coefficients: - coefficients[symbol] = \ - Fraction(coefficients[symbol], other.constant) - constant = Fraction(self.constant, other.constant) - return Expression(coefficients, constant) - if isinstance(other, Expression): - raise ValueError('non-linear expression: ' - '{} / {}'.format(self._parenstr(), other._parenstr())) - return NotImplemented - - def __rtruediv__(self, other): - if isinstance(other, self): - if self.isconstant(): - constant = Fraction(other, self.constant) - return Expression(constant=constant) - else: - raise ValueError('non-linear expression: ' - '{} / {}'.format(other._parenstr(), self._parenstr())) - return NotImplemented - - def __str__(self): - string = '' - i = 0 - for symbol in self.symbols: - coefficient = self.coefficient(symbol) - if coefficient == 1: - if i == 0: - string += symbol - else: - string += ' + {}'.format(symbol) - elif coefficient == -1: - if i == 0: - string += '-{}'.format(symbol) - else: - string += ' - {}'.format(symbol) - else: - if i == 0: - string += '{}*{}'.format(coefficient, symbol) - elif coefficient > 0: - string += ' + {}*{}'.format(coefficient, symbol) - else: - assert coefficient < 0 - coefficient *= -1 - string += ' - {}*{}'.format(coefficient, symbol) - i += 1 - constant = self.constant - if constant != 0 and i == 0: - string += '{}'.format(constant) - elif constant > 0: - string += ' + {}'.format(constant) - elif constant < 0: - constant *= -1 - string += ' - {}'.format(constant) - if string == '': - string = '0' - return string - - def _parenstr(self, always=False): - string = str(self) - if not always and (self.isconstant() or self.issymbol()): - return string - else: - return '({})'.format(string) - - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, str(self)) - - @_polymorphic_method - def __eq__(self, other): - # "normal" equality - # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs - return isinstance(other, Expression) and \ - self._coefficients == other._coefficients and \ - self.constant == other.constant - - def __hash__(self): - return hash((tuple(sorted(self._coefficients.items())), self._constant)) - - def _toint(self): - lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), - [value.denominator for value in self.values()]) - return self * lcm - - @_polymorphic_method - def _eq(self, other): - return Polyhedron(equalities=[(self - other)._toint()]) - - @_polymorphic_method - def __le__(self, other): - return Polyhedron(inequalities=[(other - self)._toint()]) - - @_polymorphic_method - def __lt__(self, other): - return Polyhedron(inequalities=[(other - self)._toint() - 1]) - - @_polymorphic_method - def __ge__(self, other): - return Polyhedron(inequalities=[(self - other)._toint()]) - - @_polymorphic_method - def __gt__(self, other): - return Polyhedron(inequalities=[(self - other)._toint() - 1]) - - @classmethod - def fromsympy(cls, expr): - import sympy - coefficients = {} - constant = 0 - for symbol, coefficient in expr.as_coefficients_dict().items(): - coefficient = Fraction(coefficient.p, coefficient.q) - if symbol == sympy.S.One: - constant = coefficient - elif isinstance(symbol, sympy.Symbol): - symbol = symbol.name - coefficients[symbol] = coefficient - else: - raise ValueError('non-linear expression: {!r}'.format(expr)) - return cls(coefficients, constant) - - def tosympy(self): - import sympy - expr = 0 - for symbol, coefficient in self.coefficients(): - term = coefficient * sympy.Symbol(symbol) - expr += term - expr += self.constant - return expr - - -class Constant(Expression): - - def __new__(cls, numerator=0, denominator=None): - self = object().__new__(cls) - if denominator is None: - if isinstance(numerator, numbers.Rational): - self._constant = numerator - elif isinstance(numerator, Constant): - self._constant = numerator.constant - else: - raise TypeError('constant must be a rational number or a Constant instance') - else: - self._constant = Fraction(numerator, denominator) - self._coefficients = {} - self._symbols = () - self._dimension = 0 - return self - - def isconstant(self): - return True - - def __bool__(self): - return bool(self.constant) - - def __repr__(self): - if self.constant.denominator == 1: - return '{}({!r})'.format(self.__class__.__name__, self.constant) - else: - return '{}({!r}, {!r})'.format(self.__class__.__name__, - self.constant.numerator, self.constant.denominator) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Rational): - return cls(expr.p, expr.q) - elif isinstance(expr, numbers.Rational): - return cls(expr) - else: - raise TypeError('expr must be a sympy.Rational instance') - - -class Symbol(Expression): - - __slots__ = Expression.__slots__ + ( - '_name', - ) - - def __new__(cls, name): - if isinstance(name, Symbol): - name = name.name - elif not isinstance(name, str): - raise TypeError('name must be a string or a Symbol instance') - self = object().__new__(cls) - self._coefficients = {name: 1} - self._constant = 0 - self._symbols = tuple(name) - self._name = name - self._dimension = 1 - return self - - @property - def name(self): - return self._name - - def issymbol(self): - return True - - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, self._name) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Symbol): - return cls(expr.name) - else: - raise TypeError('expr must be a sympy.Symbol instance') - - -def symbols(names): - if isinstance(names, str): - names = names.replace(',', ' ').split() - return (Symbol(name) for name in names) - - -@_polymorphic_operator -def eq(a, b): - return a.__eq__(b) - -@_polymorphic_operator -def le(a, b): - return a.__le__(b) - -@_polymorphic_operator -def lt(a, b): - return a.__lt__(b) - -@_polymorphic_operator -def ge(a, b): - return a.__ge__(b) - -@_polymorphic_operator -def gt(a, b): - return a.__gt__(b) - - -class Polyhedron: - """ - This class implements polyhedrons. - """ - - __slots__ = ( - '_equalities', - '_inequalities', - '_constraints', - '_symbols', - ) - - def __new__(cls, equalities=None, inequalities=None): - if isinstance(equalities, str): - if inequalities is not None: - raise TypeError('too many arguments') - return cls.fromstring(equalities) - self = super().__new__(cls) - self._equalities = [] - if equalities is not None: - for constraint in equalities: - for value in constraint.values(): - if value.denominator != 1: - raise TypeError('non-integer constraint: ' - '{} == 0'.format(constraint)) - self._equalities.append(constraint) - self._equalities = tuple(self._equalities) - self._inequalities = [] - if inequalities is not None: - for constraint in inequalities: - for value in constraint.values(): - if value.denominator != 1: - raise TypeError('non-integer constraint: ' - '{} <= 0'.format(constraint)) - self._inequalities.append(constraint) - self._inequalities = tuple(self._inequalities) - self._constraints = self._equalities + self._inequalities - self._symbols = set() - for constraint in self._constraints: - self.symbols.update(constraint.symbols) - self._symbols = tuple(sorted(self._symbols)) - return self - - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd): - equalities1, inequalities1 = cls._fromast(node.left) - equalities2, inequalities2 = cls._fromast(node.right) - equalities = equalities1 + equalities2 - inequalities = inequalities1 + inequalities2 - return equalities, inequalities - elif isinstance(node, ast.Compare): - equalities = [] - inequalities = [] - left = Expression._fromast(node.left) - for i in range(len(node.ops)): - op = node.ops[i] - right = Expression._fromast(node.comparators[i]) - if isinstance(op, ast.Lt): - inequalities.append(right - left - 1) - elif isinstance(op, ast.LtE): - inequalities.append(right - left) - elif isinstance(op, ast.Eq): - equalities.append(left - right) - elif isinstance(op, ast.GtE): - inequalities.append(left - right) - elif isinstance(op, ast.Gt): - inequalities.append(left - right - 1) - else: - break - left = right - else: - return equalities, inequalities - raise SyntaxError('invalid syntax') - - @classmethod - def fromstring(cls, string): - string = string.strip() - string = re.sub(r'^\{\s*|\s*\}$', '', string) - string = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', string) - string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string) - tokens = re.split(r',|;|and|&&|/\\|∧', string, flags=re.I) - tokens = ['({})'.format(token) for token in tokens] - string = ' & '.join(tokens) - tree = ast.parse(string, 'eval') - equalities, inequalities = cls._fromast(tree) - return cls(equalities, inequalities) - - @property - def equalities(self): - return self._equalities - - @property - def inequalities(self): - return self._inequalities - - @property - def constraints(self): - return self._constraints - - @property - def symbols(self): - return self._symbols - - @property - def dimension(self): - return len(self.symbols) - - def __bool__(self): - return not self.is_empty() - - def __contains__(self, value): - # is the value in the polyhedron? - raise NotImplementedError - - def __eq__(self, other): - # works correctly when symbols is not passed - # should be equal if values are the same even if symbols are different - bset = self._toisl() - other = other._toisl() - return bool(libisl.isl_basic_set_plain_is_equal(bset, other)) - - def isempty(self): - bset = self._toisl() - return bool(libisl.isl_basic_set_is_empty(bset)) - - def isuniverse(self): - bset = self._toisl() - return bool(libisl.isl_basic_set_is_universe(bset)) - - def isdisjoint(self, other): - # return true if the polyhedron has no elements in common with other - #symbols = self._symbolunion(other) - bset = self._toisl() - other = other._toisl() - return bool(libisl.isl_set_is_disjoint(bset, other)) - - def issubset(self, other): - # check if self(bset) is a subset of other - symbols = self._symbolunion(other) - bset = self._toisl(symbols) - other = other._toisl(symbols) - return bool(libisl.isl_set_is_strict_subset(other, bset)) - - def __le__(self, other): - return self.issubset(other) - - def __lt__(self, other): - symbols = self._symbolunion(other) - bset = self._toisl(symbols) - other = other._toisl(symbols) - return bool(libisl.isl_set_is_strict_subset(other, bset)) - - def issuperset(self, other): - # test whether every element in other is in the polyhedron - raise NotImplementedError - - def __ge__(self, other): - return self.issuperset(other) - - def __gt__(self, other): - symbols = self._symbolunion(other) - bset = self._toisl(symbols) - other = other._toisl(symbols) - bool(libisl.isl_set_is_strict_subset(other, bset)) - raise NotImplementedError - - def union(self, *others): - # return a new polyhedron with elements from the polyhedron and all - # others (convex union) - raise NotImplementedError - - def __or__(self, other): - return self.union(other) - - def intersection(self, *others): - # return a new polyhedron with elements common to the polyhedron and all - # others - # a poor man's implementation could be: - # equalities = list(self.equalities) - # inequalities = list(self.inequalities) - # for other in others: - # equalities.extend(other.equalities) - # inequalities.extend(other.inequalities) - # return self.__class__(equalities, inequalities) - raise NotImplementedError - - def __and__(self, other): - return self.intersection(other) - - def difference(self, other): - # return a new polyhedron with elements in the polyhedron that are not in the other - symbols = self._symbolunion(other) - bset = self._toisl(symbols) - other = other._toisl(symbols) - difference = libisl.isl_set_subtract(bset, other) - return difference - - def __sub__(self, other): - return self.difference(other) - - def __str__(self): - constraints = [] - for constraint in self.equalities: - constraints.append('{} == 0'.format(constraint)) - for constraint in self.inequalities: - constraints.append('{} >= 0'.format(constraint)) - return '{}'.format(', '.join(constraints)) - - def __repr__(self): - if self.isempty(): - return 'Empty' - elif self.isuniverse(): - return 'Universe' - else: - return '{}({!r})'.format(self.__class__.__name__, str(self)) - - @classmethod - def _fromsympy(cls, expr): - import sympy - equalities = [] - inequalities = [] - if expr.func == sympy.And: - for arg in expr.args: - arg_eqs, arg_ins = cls._fromsympy(arg) - equalities.extend(arg_eqs) - inequalities.extend(arg_ins) - elif expr.func == sympy.Eq: - expr = Expression.fromsympy(expr.args[0] - expr.args[1]) - equalities.append(expr) - else: - if expr.func == sympy.Lt: - expr = Expression.fromsympy(expr.args[1] - expr.args[0] - 1) - elif expr.func == sympy.Le: - expr = Expression.fromsympy(expr.args[1] - expr.args[0]) - elif expr.func == sympy.Ge: - expr = Expression.fromsympy(expr.args[0] - expr.args[1]) - elif expr.func == sympy.Gt: - expr = Expression.fromsympy(expr.args[0] - expr.args[1] - 1) - else: - raise ValueError('non-polyhedral expression: {!r}'.format(expr)) - inequalities.append(expr) - return equalities, inequalities - - @classmethod - def fromsympy(cls, expr): - import sympy - equalities, inequalities = cls._fromsympy(expr) - return cls(equalities, inequalities) - - def tosympy(self): - import sympy - constraints = [] - for equality in self.equalities: - constraints.append(sympy.Eq(equality.tosympy(), 0)) - for inequality in self.inequalities: - constraints.append(sympy.Ge(inequality.tosympy(), 0)) - return sympy.And(*constraints) - - def _symbolunion(self, *others): - symbols = set(self.symbols) - for other in others: - symbols.update(other.symbols) - return sorted(symbols) - - def _toisl(self, symbols=None): - if symbols is None: - symbols = self.symbols - dimension = len(symbols) - space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension) - bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space)) - ls = libisl.isl_local_space_from_space(space) - for equality in self.equalities: - ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls)) - for symbol, coefficient in equality.coefficients(): - val = str(coefficient).encode() - val = libisl.isl_val_read_from_str(_main_ctx, val) - dim = symbols.index(symbol) - ceq = libisl.isl_constraint_set_coefficient_val(ceq, libisl.isl_dim_set, dim, val) - if equality.constant != 0: - val = str(equality.constant).encode() - val = libisl.isl_val_read_from_str(_main_ctx, val) - ceq = libisl.isl_constraint_set_constant_val(ceq, val) - bset = libisl.isl_basic_set_add_constraint(bset, ceq) - for inequality in self.inequalities: - cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls)) - for symbol, coefficient in inequality.coefficients(): - val = str(coefficient).encode() - val = libisl.isl_val_read_from_str(_main_ctx, val) - dim = symbols.index(symbol) - cin = libisl.isl_constraint_set_coefficient_val(cin, libisl.isl_dim_set, dim, val) - if inequality.constant != 0: - val = str(inequality.constant).encode() - val = libisl.isl_val_read_from_str(_main_ctx, val) - cin = libisl.isl_constraint_set_constant_val(cin, val) - bset = libisl.isl_basic_set_add_constraint(bset, cin) - bset = isl.BasicSet(bset) - return bset - - @classmethod - def _fromisl(cls, bset, symbols): - raise NotImplementedError - equalities = ... - inequalities = ... - return cls(equalities, inequalities) - '''takes basic set in isl form and puts back into python version of polyhedron - isl example code gives isl form as: - "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}") - our printer is giving form as: - { [i0, i1] : 2i1 >= -2 - i0 } ''' - -Empty = eq(0,1) - -Universe = Polyhedron() - - -if __name__ == '__main__': - #p = Polyhedron('2a + 2b + 1 == 0') # empty - p = Polyhedron('3x + 2y + 3 == 0, y == 0') # not empty - ip = p._toisl() - print(ip) - print(ip.constraints()) diff --git a/pypol/linexprs.py b/pypol/linexprs.py new file mode 100644 index 0000000..0db7edd --- /dev/null +++ b/pypol/linexprs.py @@ -0,0 +1,431 @@ +import ast +import functools +import numbers +import re + +from fractions import Fraction, gcd + + +__all__ = [ + 'Expression', + 'Symbol', 'symbols', + 'Constant', +] + + +def _polymorphic(func): + @functools.wraps(func) + def wrapper(left, right): + if isinstance(right, Expression): + return func(left, right) + elif isinstance(right, numbers.Rational): + right = Constant(right) + return func(left, right) + return NotImplemented + return wrapper + + +class Expression: + """ + This class implements linear expressions. + """ + + __slots__ = ( + '_coefficients', + '_constant', + '_symbols', + '_dimension', + ) + + def __new__(cls, coefficients=None, constant=0): + if isinstance(coefficients, str): + if constant: + raise TypeError('too many arguments') + return cls.fromstring(coefficients) + if isinstance(coefficients, dict): + coefficients = coefficients.items() + if coefficients is None: + return Constant(constant) + coefficients = [(symbol, coefficient) + for symbol, coefficient in coefficients if coefficient != 0] + if len(coefficients) == 0: + return Constant(constant) + elif len(coefficients) == 1 and constant == 0: + symbol, coefficient = coefficients[0] + if coefficient == 1: + return Symbol(symbol) + self = object().__new__(cls) + self._coefficients = {} + for symbol, coefficient in coefficients: + if isinstance(symbol, Symbol): + symbol = symbol.name + elif not isinstance(symbol, str): + raise TypeError('symbols must be strings or Symbol instances') + if isinstance(coefficient, Constant): + coefficient = coefficient.constant + if not isinstance(coefficient, numbers.Rational): + raise TypeError('coefficients must be rational numbers ' + 'or Constant instances') + self._coefficients[symbol] = coefficient + if isinstance(constant, Constant): + constant = constant.constant + if not isinstance(constant, numbers.Rational): + raise TypeError('constant must be a rational number ' + 'or a Constant instance') + self._constant = constant + self._symbols = tuple(sorted(self._coefficients)) + self._dimension = len(self._symbols) + return self + + def coefficient(self, symbol): + if isinstance(symbol, Symbol): + symbol = str(symbol) + elif not isinstance(symbol, str): + raise TypeError('symbol must be a string or a Symbol instance') + try: + return self._coefficients[symbol] + except KeyError: + return 0 + + __getitem__ = coefficient + + def coefficients(self): + for symbol in self.symbols: + yield symbol, self.coefficient(symbol) + + @property + def constant(self): + return self._constant + + @property + def symbols(self): + return self._symbols + + @property + def dimension(self): + return self._dimension + + def isconstant(self): + return False + + def issymbol(self): + return False + + def values(self): + for symbol in self.symbols: + yield self.coefficient(symbol) + yield self.constant + + def __bool__(self): + return True + + def __pos__(self): + return self + + def __neg__(self): + return self * -1 + + @_polymorphic + def __add__(self, other): + coefficients = dict(self.coefficients()) + for symbol, coefficient in other.coefficients(): + if symbol in coefficients: + coefficients[symbol] += coefficient + else: + coefficients[symbol] = coefficient + constant = self.constant + other.constant + return Expression(coefficients, constant) + + __radd__ = __add__ + + @_polymorphic + def __sub__(self, other): + coefficients = dict(self.coefficients()) + for symbol, coefficient in other.coefficients(): + if symbol in coefficients: + coefficients[symbol] -= coefficient + else: + coefficients[symbol] = -coefficient + constant = self.constant - other.constant + return Expression(coefficients, constant) + + def __rsub__(self, other): + return -(self - other) + + @_polymorphic + def __mul__(self, other): + if other.isconstant(): + coefficients = dict(self.coefficients()) + for symbol in coefficients: + coefficients[symbol] *= other.constant + constant = self.constant * other.constant + return Expression(coefficients, constant) + if isinstance(other, Expression) and not self.isconstant(): + raise ValueError('non-linear expression: ' + '{} * {}'.format(self._parenstr(), other._parenstr())) + return NotImplemented + + __rmul__ = __mul__ + + @_polymorphic + def __truediv__(self, other): + if other.isconstant(): + coefficients = dict(self.coefficients()) + for symbol in coefficients: + coefficients[symbol] = \ + Fraction(coefficients[symbol], other.constant) + constant = Fraction(self.constant, other.constant) + return Expression(coefficients, constant) + if isinstance(other, Expression): + raise ValueError('non-linear expression: ' + '{} / {}'.format(self._parenstr(), other._parenstr())) + return NotImplemented + + def __rtruediv__(self, other): + if isinstance(other, self): + if self.isconstant(): + constant = Fraction(other, self.constant) + return Expression(constant=constant) + else: + raise ValueError('non-linear expression: ' + '{} / {}'.format(other._parenstr(), self._parenstr())) + return NotImplemented + + @_polymorphic + def __eq__(self, other): + # "normal" equality + # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs + return isinstance(other, Expression) and \ + self._coefficients == other._coefficients and \ + self.constant == other.constant + + @_polymorphic + def __le__(self, other): + from .polyhedra import Le + return Le(self, other) + + @_polymorphic + def __lt__(self, other): + from .polyhedra import Lt + return Lt(self, other) + + @_polymorphic + def __ge__(self, other): + from .polyhedra import Ge + return Ge(self, other) + + @_polymorphic + def __gt__(self, other): + from .polyhedra import Gt + return Gt(self, other) + + def __hash__(self): + return hash((tuple(sorted(self._coefficients.items())), self._constant)) + + def _toint(self): + lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), + [value.denominator for value in self.values()]) + return self * lcm + + @classmethod + def _fromast(cls, node): + if isinstance(node, ast.Module) and len(node.body) == 1: + return cls._fromast(node.body[0]) + elif isinstance(node, ast.Expr): + return cls._fromast(node.value) + elif isinstance(node, ast.Name): + return Symbol(node.id) + elif isinstance(node, ast.Num): + return Constant(node.n) + elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): + return -cls._fromast(node.operand) + elif isinstance(node, ast.BinOp): + left = cls._fromast(node.left) + right = cls._fromast(node.right) + if isinstance(node.op, ast.Add): + return left + right + elif isinstance(node.op, ast.Sub): + return left - right + elif isinstance(node.op, ast.Mult): + return left * right + elif isinstance(node.op, ast.Div): + return left / right + raise SyntaxError('invalid syntax') + + @classmethod + def fromstring(cls, string): + string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string) + tree = ast.parse(string, 'eval') + return cls._fromast(tree) + + def __str__(self): + string = '' + i = 0 + for symbol in self.symbols: + coefficient = self.coefficient(symbol) + if coefficient == 1: + if i == 0: + string += symbol + else: + string += ' + {}'.format(symbol) + elif coefficient == -1: + if i == 0: + string += '-{}'.format(symbol) + else: + string += ' - {}'.format(symbol) + else: + if i == 0: + string += '{}*{}'.format(coefficient, symbol) + elif coefficient > 0: + string += ' + {}*{}'.format(coefficient, symbol) + else: + assert coefficient < 0 + coefficient *= -1 + string += ' - {}*{}'.format(coefficient, symbol) + i += 1 + constant = self.constant + if constant != 0 and i == 0: + string += '{}'.format(constant) + elif constant > 0: + string += ' + {}'.format(constant) + elif constant < 0: + constant *= -1 + string += ' - {}'.format(constant) + if string == '': + string = '0' + return string + + def _parenstr(self, always=False): + string = str(self) + if not always and (self.isconstant() or self.issymbol()): + return string + else: + return '({})'.format(string) + + def __repr__(self): + return '{}({!r})'.format(self.__class__.__name__, str(self)) + + @classmethod + def fromsympy(cls, expr): + import sympy + coefficients = {} + constant = 0 + for symbol, coefficient in expr.as_coefficients_dict().items(): + coefficient = Fraction(coefficient.p, coefficient.q) + if symbol == sympy.S.One: + constant = coefficient + elif isinstance(symbol, sympy.Symbol): + symbol = symbol.name + coefficients[symbol] = coefficient + else: + raise ValueError('non-linear expression: {!r}'.format(expr)) + return cls(coefficients, constant) + + def tosympy(self): + import sympy + expr = 0 + for symbol, coefficient in self.coefficients(): + term = coefficient * sympy.Symbol(symbol) + expr += term + expr += self.constant + return expr + + +class Symbol(Expression): + + __slots__ = Expression.__slots__ + ( + '_name', + ) + + def __new__(cls, name): + if isinstance(name, Symbol): + name = name.name + elif not isinstance(name, str): + raise TypeError('name must be a string or a Symbol instance') + name = name.strip() + self = object().__new__(cls) + self._coefficients = {name: 1} + self._constant = 0 + self._symbols = tuple(name) + self._name = name + self._dimension = 1 + return self + + @property + def name(self): + return self._name + + def issymbol(self): + return True + + @classmethod + def _fromast(cls, node): + if isinstance(node, ast.Module) and len(node.body) == 1: + return cls._fromast(node.body[0]) + elif isinstance(node, ast.Expr): + return cls._fromast(node.value) + elif isinstance(node, ast.Name): + return Symbol(node.id) + raise SyntaxError('invalid syntax') + + def __repr__(self): + return '{}({!r})'.format(self.__class__.__name__, self._name) + + @classmethod + def fromsympy(cls, expr): + import sympy + if isinstance(expr, sympy.Symbol): + return cls(expr.name) + else: + raise TypeError('expr must be a sympy.Symbol instance') + + +def symbols(names): + if isinstance(names, str): + names = names.replace(',', ' ').split() + return (Symbol(name) for name in names) + + +class Constant(Expression): + + def __new__(cls, numerator=0, denominator=None): + self = object().__new__(cls) + if denominator is None and isinstance(numerator, Constant): + self._constant = numerator.constant + else: + self._constant = Fraction(numerator, denominator) + self._coefficients = {} + self._symbols = () + self._dimension = 0 + return self + + def isconstant(self): + return True + + def __bool__(self): + return self.constant != 0 + + @classmethod + def fromstring(cls, string): + if isinstance(string, str): + return Constant(Fraction(string)) + else: + raise TypeError('string must be a string instance') + + def __repr__(self): + if self.constant.denominator == 1: + return '{}({!r})'.format(self.__class__.__name__, + self.constant.numerator) + else: + return '{}({!r}, {!r})'.format(self.__class__.__name__, + self.constant.numerator, self.constant.denominator) + + @classmethod + def fromsympy(cls, expr): + import sympy + if isinstance(expr, sympy.Rational): + return cls(expr.p, expr.q) + elif isinstance(expr, numbers.Rational): + return cls(expr) + else: + raise TypeError('expr must be a sympy.Rational instance') diff --git a/pypol/polyhedra.py b/pypol/polyhedra.py new file mode 100644 index 0000000..787e965 --- /dev/null +++ b/pypol/polyhedra.py @@ -0,0 +1,304 @@ +import ast +import functools +import numbers +import re + +from . import islhelper + +from .islhelper import mainctx, libisl +from .linexprs import Expression, Constant +from .domains import Domain + + +__all__ = [ + 'Polyhedron', + 'Lt', 'Le', 'Eq', 'Ne', 'Ge', 'Gt', + 'Empty', 'Universe', +] + + +class Polyhedron(Domain): + + __slots__ = ( + '_equalities', + '_inequalities', + '_constraints', + '_symbols', + '_dimension', + ) + + def __new__(cls, equalities=None, inequalities=None): + if isinstance(equalities, str): + if inequalities is not None: + raise TypeError('too many arguments') + return cls.fromstring(equalities) + elif isinstance(equalities, Polyhedron): + if inequalities is not None: + raise TypeError('too many arguments') + return equalities + elif isinstance(equalities, Domain): + if inequalities is not None: + raise TypeError('too many arguments') + return equalities.polyhedral_hull() + if equalities is None: + equalities = [] + else: + for i, equality in enumerate(equalities): + if not isinstance(equality, Expression): + raise TypeError('equalities must be linear expressions') + equalities[i] = equality._toint() + if inequalities is None: + inequalities = [] + else: + for i, inequality in enumerate(inequalities): + if not isinstance(inequality, Expression): + raise TypeError('inequalities must be linear expressions') + inequalities[i] = inequality._toint() + symbols = cls._xsymbols(equalities + inequalities) + islbset = cls._toislbasicset(equalities, inequalities, symbols) + return cls._fromislbasicset(islbset, symbols) + + @property + def equalities(self): + return self._equalities + + @property + def inequalities(self): + return self._inequalities + + @property + def constraints(self): + return self._constraints + + @property + def polyhedra(self): + return self, + + def disjoint(self): + return self + + def isuniverse(self): + islbset = self._toislbasicset(self.equalities, self.inequalities, + self.symbols) + universe = bool(libisl.isl_basic_set_is_universe(islbset)) + libisl.isl_basic_set_free(islbset) + return universe + + def polyhedral_hull(self): + return self + + @classmethod + def _fromislbasicset(cls, islbset, symbols): + islconstraints = islhelper.isl_basic_set_constraints(islbset) + equalities = [] + inequalities = [] + for islconstraint in islconstraints: + islpr = libisl.isl_printer_to_str(mainctx) + constant = libisl.isl_constraint_get_constant_val(islconstraint) + constant = islhelper.isl_val_to_int(constant) + coefficients = {} + for dim, symbol in enumerate(symbols): + coefficient = libisl.isl_constraint_get_coefficient_val(islconstraint, libisl.isl_dim_set, dim) + coefficient = islhelper.isl_val_to_int(coefficient) + if coefficient != 0: + coefficients[symbol] = coefficient + expression = Expression(coefficients, constant) + if libisl.isl_constraint_is_equality(islconstraint): + equalities.append(expression) + else: + inequalities.append(expression) + libisl.isl_basic_set_free(islbset) + self = object().__new__(Polyhedron) + self._equalities = tuple(equalities) + self._inequalities = tuple(inequalities) + self._constraints = tuple(equalities + inequalities) + self._symbols = cls._xsymbols(self._constraints) + self._dimension = len(self._symbols) + return self + + @classmethod + def _toislbasicset(cls, equalities, inequalities, symbols): + dimension = len(symbols) + islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension) + islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp)) + islls = libisl.isl_local_space_from_space(islsp) + for equality in equalities: + isleq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(islls)) + for symbol, coefficient in equality.coefficients(): + val = str(coefficient).encode() + val = libisl.isl_val_read_from_str(mainctx, val) + sid = symbols.index(symbol) + isleq = libisl.isl_constraint_set_coefficient_val(isleq, + libisl.isl_dim_set, sid, val) + if equality.constant != 0: + val = str(equality.constant).encode() + val = libisl.isl_val_read_from_str(mainctx, val) + isleq = libisl.isl_constraint_set_constant_val(isleq, val) + islbset = libisl.isl_basic_set_add_constraint(islbset, isleq) + for inequality in inequalities: + islin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(islls)) + for symbol, coefficient in inequality.coefficients(): + val = str(coefficient).encode() + val = libisl.isl_val_read_from_str(mainctx, val) + sid = symbols.index(symbol) + islin = libisl.isl_constraint_set_coefficient_val(islin, + libisl.isl_dim_set, sid, val) + if inequality.constant != 0: + val = str(inequality.constant).encode() + val = libisl.isl_val_read_from_str(mainctx, val) + islin = libisl.isl_constraint_set_constant_val(islin, val) + islbset = libisl.isl_basic_set_add_constraint(islbset, islin) + return islbset + + @classmethod + def _fromast(cls, node): + if isinstance(node, ast.Module) and len(node.body) == 1: + return cls._fromast(node.body[0]) + elif isinstance(node, ast.Expr): + return cls._fromast(node.value) + elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd): + equalities1, inequalities1 = cls._fromast(node.left) + equalities2, inequalities2 = cls._fromast(node.right) + equalities = equalities1 + equalities2 + inequalities = inequalities1 + inequalities2 + return equalities, inequalities + elif isinstance(node, ast.Compare): + equalities = [] + inequalities = [] + left = Expression._fromast(node.left) + for i in range(len(node.ops)): + op = node.ops[i] + right = Expression._fromast(node.comparators[i]) + if isinstance(op, ast.Lt): + inequalities.append(right - left - 1) + elif isinstance(op, ast.LtE): + inequalities.append(right - left) + elif isinstance(op, ast.Eq): + equalities.append(left - right) + elif isinstance(op, ast.GtE): + inequalities.append(left - right) + elif isinstance(op, ast.Gt): + inequalities.append(left - right - 1) + else: + break + left = right + else: + return equalities, inequalities + raise SyntaxError('invalid syntax') + + @classmethod + def fromstring(cls, string): + string = string.strip() + string = re.sub(r'^\{\s*|\s*\}$', '', string) + string = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', string) + string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string) + tokens = re.split(r',|;|and|&&|/\\|∧', string, flags=re.I) + tokens = ['({})'.format(token) for token in tokens] + string = ' & '.join(tokens) + tree = ast.parse(string, 'eval') + equalities, inequalities = cls._fromast(tree) + return cls(equalities, inequalities) + + def __repr__(self): + if self.isempty(): + return 'Empty' + elif self.isuniverse(): + return 'Universe' + else: + strings = [] + for equality in self.equalities: + strings.append('Eq({}, 0)'.format(equality)) + for inequality in self.inequalities: + strings.append('Ge({}, 0)'.format(inequality)) + if len(strings) == 1: + return strings[0] + else: + return 'And({})'.format(', '.join(strings)) + + @classmethod + def _fromsympy(cls, expr): + import sympy + equalities = [] + inequalities = [] + if expr.func == sympy.And: + for arg in expr.args: + arg_eqs, arg_ins = cls._fromsympy(arg) + equalities.extend(arg_eqs) + inequalities.extend(arg_ins) + elif expr.func == sympy.Eq: + expr = Expression.fromsympy(expr.args[0] - expr.args[1]) + equalities.append(expr) + else: + if expr.func == sympy.Lt: + expr = Expression.fromsympy(expr.args[1] - expr.args[0] - 1) + elif expr.func == sympy.Le: + expr = Expression.fromsympy(expr.args[1] - expr.args[0]) + elif expr.func == sympy.Ge: + expr = Expression.fromsympy(expr.args[0] - expr.args[1]) + elif expr.func == sympy.Gt: + expr = Expression.fromsympy(expr.args[0] - expr.args[1] - 1) + else: + raise ValueError('non-polyhedral expression: {!r}'.format(expr)) + inequalities.append(expr) + return equalities, inequalities + + @classmethod + def fromsympy(cls, expr): + import sympy + equalities, inequalities = cls._fromsympy(expr) + return cls(equalities, inequalities) + + def tosympy(self): + import sympy + constraints = [] + for equality in self.equalities: + constraints.append(sympy.Eq(equality.tosympy(), 0)) + for inequality in self.inequalities: + constraints.append(sympy.Ge(inequality.tosympy(), 0)) + return sympy.And(*constraints) + + +def _polymorphic(func): + @functools.wraps(func) + def wrapper(left, right): + if isinstance(left, numbers.Rational): + left = Constant(left) + elif not isinstance(left, Expression): + raise TypeError('left must be a a rational number ' + 'or a linear expression') + if isinstance(right, numbers.Rational): + right = Constant(right) + elif not isinstance(right, Expression): + raise TypeError('right must be a a rational number ' + 'or a linear expression') + return func(left, right) + return wrapper + +@_polymorphic +def Lt(left, right): + return Polyhedron([], [right - left - 1]) + +@_polymorphic +def Le(left, right): + return Polyhedron([], [right - left]) + +@_polymorphic +def Eq(left, right): + return Polyhedron([left - right], []) + +@_polymorphic +def Ne(left, right): + return ~Eq(left, right) + +@_polymorphic +def Gt(left, right): + return Polyhedron([], [left - right - 1]) + +@_polymorphic +def Ge(left, right): + return Polyhedron([], [left - right]) + + +Empty = Eq(1, 0) + +Universe = Polyhedron([]) diff --git a/tests/__init__.py b/pypol/tests/__init__.py similarity index 100% rename from tests/__init__.py rename to pypol/tests/__init__.py diff --git a/pypol/tests/test_domains.py b/pypol/tests/test_domains.py new file mode 100644 index 0000000..f9e7008 --- /dev/null +++ b/pypol/tests/test_domains.py @@ -0,0 +1,12 @@ +import unittest + +from ..domains import * + + +class TestDomain(unittest.TestCase): + + def setUp(self): + pass + + def test_new(self): + pass diff --git a/tests/test_linear.py b/pypol/tests/test_linexprs.py similarity index 80% rename from tests/test_linear.py rename to pypol/tests/test_linexprs.py index 6cd1ff4..1606ea0 100644 --- a/tests/test_linear.py +++ b/pypol/tests/test_linexprs.py @@ -3,7 +3,7 @@ import unittest from fractions import Fraction -from pypol.linear import * +from ..linexprs import * try: @@ -32,15 +32,13 @@ class TestExpression(unittest.TestCase): self.pi = Expression(constant=Fraction(22, 7)) self.expr = self.x - 2*self.y + 3 - def test_new_subclass(self): + def test_new(self): self.assertIsInstance(self.x, Symbol) self.assertIsInstance(self.pi, Constant) self.assertNotIsInstance(self.x + self.pi, Symbol) self.assertNotIsInstance(self.x + self.pi, Constant) xx = Expression({'x': 2}) self.assertNotIsInstance(xx, Symbol) - - def test_new_types(self): with self.assertRaises(TypeError): Expression('x + y', 2) self.assertEqual(Expression({'x': 2}), Expression({self.x: 2})) @@ -49,18 +47,9 @@ class TestExpression(unittest.TestCase): with self.assertRaises(TypeError): Expression({'x': '2'}) self.assertEqual(Expression(constant=1), Expression(constant=self.one)) - with self.assertRaises(TypeError): - Expression(constant='1') - - def test_symbols(self): - self.assertCountEqual(self.x.symbols, ['x']) - self.assertCountEqual(self.pi.symbols, []) - self.assertCountEqual(self.expr.symbols, ['x', 'y']) - - def test_dimension(self): - self.assertEqual(self.x.dimension, 1) - self.assertEqual(self.pi.dimension, 0) - self.assertEqual(self.expr.dimension, 2) + self.assertEqual(Expression(constant='1'), Expression(constant=self.one)) + with self.assertRaises(ValueError): + Expression(constant='a') def test_coefficient(self): self.assertEqual(self.expr.coefficient('x'), 1) @@ -90,19 +79,29 @@ class TestExpression(unittest.TestCase): self.assertEqual(self.pi.constant, Fraction(22, 7)) self.assertEqual(self.expr.constant, 3) + def test_symbols(self): + self.assertCountEqual(self.x.symbols, ['x']) + self.assertCountEqual(self.pi.symbols, []) + self.assertCountEqual(self.expr.symbols, ['x', 'y']) + + def test_dimension(self): + self.assertEqual(self.x.dimension, 1) + self.assertEqual(self.pi.dimension, 0) + self.assertEqual(self.expr.dimension, 2) + def test_isconstant(self): self.assertFalse(self.x.isconstant()) self.assertTrue(self.pi.isconstant()) self.assertFalse(self.expr.isconstant()) - def test_values(self): - self.assertCountEqual(self.expr.values(), [1, -2, 3]) - def test_issymbol(self): self.assertTrue(self.x.issymbol()) self.assertFalse(self.pi.issymbol()) self.assertFalse(self.expr.issymbol()) + def test_values(self): + self.assertCountEqual(self.expr.values(), [1, -2, 3]) + def test_bool(self): self.assertTrue(self.x) self.assertFalse(self.zero) @@ -132,11 +131,28 @@ class TestExpression(unittest.TestCase): self.assertEqual(0 * self.expr, 0) self.assertEqual(self.expr * 2, 2*self.x - 4*self.y + 6) - def test_div(self): + def test_truediv(self): with self.assertRaises(ZeroDivisionError): self.expr / 0 self.assertEqual(self.expr / 2, self.x / 2 - self.y + Fraction(3, 2)) + def test_eq(self): + self.assertEqual(self.expr, self.expr) + self.assertNotEqual(self.x, self.y) + self.assertEqual(self.zero, 0) + + def test__toint(self): + self.assertEqual((self.x + self.y/2 + self.z/3)._toint(), + 6*self.x + 3*self.y + 2*self.z) + + def test_fromstring(self): + self.assertEqual(Expression.fromstring('x'), self.x) + self.assertEqual(Expression.fromstring('-x'), -self.x) + self.assertEqual(Expression.fromstring('22/7'), self.pi) + self.assertEqual(Expression.fromstring('x - 2y + 3'), self.expr) + self.assertEqual(Expression.fromstring('x - (3-1)y + 3'), self.expr) + self.assertEqual(Expression.fromstring('x - 2*y + 3'), self.expr) + def test_str(self): self.assertEqual(str(Expression()), '0') self.assertEqual(str(self.x), 'x') @@ -151,23 +167,6 @@ class TestExpression(unittest.TestCase): self.assertEqual(repr(self.x + self.one), "Expression('x + 1')") self.assertEqual(repr(self.expr), "Expression('x - 2*y + 3')") - def test_fromstring(self): - self.assertEqual(Expression.fromstring('x'), self.x) - self.assertEqual(Expression.fromstring('-x'), -self.x) - self.assertEqual(Expression.fromstring('22/7'), self.pi) - self.assertEqual(Expression.fromstring('x - 2y + 3'), self.expr) - self.assertEqual(Expression.fromstring('x - (3-1)y + 3'), self.expr) - self.assertEqual(Expression.fromstring('x - 2*y + 3'), self.expr) - - def test_eq(self): - self.assertEqual(self.expr, self.expr) - self.assertNotEqual(self.x, self.y) - self.assertEqual(self.zero, 0) - - def test__toint(self): - self.assertEqual((self.x + self.y/2 + self.z/3)._toint(), - 6*self.x + 3*self.y + 2*self.z) - @_requires_sympy def test_fromsympy(self): sp_x, sp_y = sympy.symbols('x y') @@ -185,33 +184,34 @@ class TestExpression(unittest.TestCase): self.assertEqual(self.expr.tosympy(), sp_x - 2*sp_y + 3) -class TestConstant(unittest.TestCase): - - def setUp(self): - self.zero = Constant(0) - self.one = Constant(1) - self.pi = Constant(Fraction(22, 7)) - - @_requires_sympy - def test_fromsympy(self): - self.assertEqual(Constant.fromsympy(sympy.Rational(22, 7)), self.pi) - with self.assertRaises(TypeError): - Constant.fromsympy(sympy.Symbol('x')) - - class TestSymbol(unittest.TestCase): def setUp(self): self.x = Symbol('x') self.y = Symbol('y') + def test_new(self): + self.assertEqual(Symbol(' x '), self.x) + self.assertEqual(Symbol(self.x), self.x) + with self.assertRaises(TypeError): + Symbol(1) + def test_name(self): self.assertEqual(self.x.name, 'x') - def test_symbols(self): - self.assertListEqual(list(symbols('x y')), [self.x, self.y]) - self.assertListEqual(list(symbols('x,y')), [self.x, self.y]) - self.assertListEqual(list(symbols(['x', 'y'])), [self.x, self.y]) + def test_issymbol(self): + self.assertTrue(self.x.issymbol()) + + def test_fromstring(self): + self.assertEqual(Symbol.fromstring('x'), self.x) + with self.assertRaises(SyntaxError): + Symbol.fromstring('1') + + def test_str(self): + self.assertEqual(str(self.x), 'x') + + def test_repr(self): + self.assertEqual(repr(self.x), "Symbol('x')") @_requires_sympy def test_fromsympy(self): @@ -224,75 +224,46 @@ class TestSymbol(unittest.TestCase): with self.assertRaises(TypeError): Symbol.fromsympy(sp_x*sp_x) - -class TestOperators(unittest.TestCase): - - pass + def test_symbols(self): + self.assertListEqual(list(symbols('x y')), [self.x, self.y]) + self.assertListEqual(list(symbols('x,y')), [self.x, self.y]) + self.assertListEqual(list(symbols(['x', 'y'])), [self.x, self.y]) -class TestPolyhedron(unittest.TestCase): +class TestConstant(unittest.TestCase): def setUp(self): - x, y = symbols('x y') - self.square = Polyhedron(inequalities=[x, 1 - x, y, 1 - y]) - - def test_symbols(self): - self.assertCountEqual(self.square.symbols, ['x', 'y']) + self.zero = Constant(0) + self.one = Constant(1) + self.pi = Constant(Fraction(22, 7)) - def test_dimension(self): - self.assertEqual(self.square.dimension, 2) + def test_new(self): + self.assertEqual(Constant(), self.zero) + self.assertEqual(Constant(1), self.one) + self.assertEqual(Constant(self.pi), self.pi) + self.assertEqual(Constant('22/7'), self.pi) - def test_str(self): - self.assertEqual(str(self.square), - 'x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0') + def test_isconstant(self): + self.assertTrue(self.zero.isconstant()) - def test_repr(self): - self.assertEqual(repr(self.square), - "Polyhedron('x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0')") + def test_bool(self): + self.assertFalse(self.zero) + self.assertTrue(self.pi) def test_fromstring(self): - self.assertEqual(Polyhedron.fromstring('{x >= 0, -x + 1 >= 0, ' - 'y >= 0, -y + 1 >= 0}'), self.square) - - def test_isempty(self): - self.assertFalse(self.square.isempty()) + self.assertEqual(Constant.fromstring('22/7'), self.pi) + with self.assertRaises(ValueError): + Constant.fromstring('a') + with self.assertRaises(TypeError): + Constant.fromstring(1) - def test_isuniverse(self): - self.assertFalse(self.square.isuniverse()) + def test_repr(self): + self.assertEqual(repr(self.zero), 'Constant(0)') + self.assertEqual(repr(self.one), 'Constant(1)') + self.assertEqual(repr(self.pi), 'Constant(22, 7)') - @unittest.expectedFailure @_requires_sympy def test_fromsympy(self): - sp_x, sp_y = sympy.symbols('x y') - self.assertEqual(Polyhedron.fromsympy((sp_x >= 0) & (sp_x <= 1) & - (sp_y >= 0) & (sp_y <= 1)), self.square) - - @_requires_sympy - def test_tosympy(self): - sp_x, sp_y = sympy.symbols('x y') - self.assertEqual(self.square.tosympy(), - sympy.And(-sp_x + 1 >= 0, -sp_y + 1 >= 0, sp_x >= 0, sp_y >= 0)) - - -class TestEmpty: - - def test_repr(self): - self.assertEqual(repr(Empty), 'Empty') - - def test_isempty(self): - self.assertTrue(Empty.isempty()) - - def test_isuniverse(self): - self.assertFalse(Empty.isuniverse()) - - -class TestUniverse: - - def test_repr(self): - self.assertEqual(repr(Universe), 'Universe') - - def test_isempty(self): - self.assertTrue(Universe.isempty()) - - def test_isuniverse(self): - self.assertTrue(Universe.isuniverse()) + self.assertEqual(Constant.fromsympy(sympy.Rational(22, 7)), self.pi) + with self.assertRaises(TypeError): + Constant.fromsympy(sympy.Symbol('x')) diff --git a/pypol/tests/test_polyhedra.py b/pypol/tests/test_polyhedra.py new file mode 100644 index 0000000..c74e25f --- /dev/null +++ b/pypol/tests/test_polyhedra.py @@ -0,0 +1,87 @@ +import functools +import unittest + +from ..linexprs import symbols +from ..polyhedra import * + + +try: + import sympy + def _requires_sympy(func): + @functools.wraps(func) + def wrapper(self): + return func(self) + return wrapper +except ImportError: + def _requires_sympy(func): + @functools.wraps(func) + def wrapper(self): + raise unittest.SkipTest('SymPy is not available') + return wrapper + + +class TestPolyhedron(unittest.TestCase): + + def setUp(self): + x, y = symbols('x y') + self.square = Polyhedron(inequalities=[x, 1 - x, y, 1 - y]) + + def test_symbols(self): + self.assertCountEqual(self.square.symbols, ['x', 'y']) + + def test_dimension(self): + self.assertEqual(self.square.dimension, 2) + + def test_str(self): + self.assertEqual(str(self.square), + 'And(Ge(x, 0), Ge(-x + 1, 0), Ge(y, 0), Ge(-y + 1, 0))') + + def test_repr(self): + self.assertEqual(repr(self.square), + "And(Ge(x, 0), Ge(-x + 1, 0), Ge(y, 0), Ge(-y + 1, 0))") + + def test_fromstring(self): + self.assertEqual(Polyhedron.fromstring('{x >= 0, -x + 1 >= 0, ' + 'y >= 0, -y + 1 >= 0}'), self.square) + + def test_isempty(self): + self.assertFalse(self.square.isempty()) + + def test_isuniverse(self): + self.assertFalse(self.square.isuniverse()) + + @_requires_sympy + def test_fromsympy(self): + sp_x, sp_y = sympy.symbols('x y') + self.assertEqual(Polyhedron.fromsympy((sp_x >= 0) & (sp_x <= 1) & + (sp_y >= 0) & (sp_y <= 1)), self.square) + + @_requires_sympy + def test_tosympy(self): + sp_x, sp_y = sympy.symbols('x y') + self.assertEqual(self.square.tosympy(), + sympy.And(-sp_x + 1 >= 0, -sp_y + 1 >= 0, sp_x >= 0, sp_y >= 0)) + + +class TestEmpty: + + def test_repr(self): + self.assertEqual(repr(Empty), 'Empty') + + def test_isempty(self): + self.assertTrue(Empty.isempty()) + + def test_isuniverse(self): + self.assertFalse(Empty.isuniverse()) + + +class TestUniverse: + + def test_repr(self): + self.assertEqual(repr(Universe), 'Universe') + + def test_isempty(self): + self.assertTrue(Universe.isempty()) + + def test_isuniverse(self): + self.assertTrue(Universe.isuniverse()) diff --git a/setup.py b/setup.py index b530f28..368983e 100755 --- a/setup.py +++ b/setup.py @@ -8,8 +8,8 @@ setup( author='MINES ParisTech', packages=['pypol'], ext_modules = [ - Extension('pypol._isl', - sources=['pypol/_isl.c'], + Extension('pypol._islhelper', + sources=['pypol/_islhelper.c'], libraries=['isl']) ] ) diff --git a/tests/test_isl.py b/tests/test_isl.py deleted file mode 100644 index 21374b5..0000000 --- a/tests/test_isl.py +++ /dev/null @@ -1,13 +0,0 @@ -import unittest - -from math import floor, ceil, trunc - -from pypol.isl import * - - -class TestContext(unittest.TestCase): - - def test_eq(self): - ctx1, ctx2 = Context(), Context() - self.assertEqual(ctx1, ctx1) - self.assertNotEqual(ctx1, ctx2) -- 2.20.1