X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/1c48ba9f3483505f53962731dc58c6c02d785fc4..7f60c3d845f3035ce675bfe4cc0d2d01456013c6:/pypol/polyhedra.py diff --git a/pypol/polyhedra.py b/pypol/polyhedra.py index d181646..6b44bdc 100644 --- a/pypol/polyhedra.py +++ b/pypol/polyhedra.py @@ -4,7 +4,7 @@ import numbers from . import islhelper from .islhelper import mainctx, libisl -from .linexprs import Expression, Constant +from .linexprs import Expression, Rational from .domains import Domain @@ -44,14 +44,14 @@ class Polyhedron(Domain): for i, equality in enumerate(equalities): if not isinstance(equality, Expression): raise TypeError('equalities must be linear expressions') - equalities[i] = equality._toint() + equalities[i] = equality.scaleint() if inequalities is None: inequalities = [] else: for i, inequality in enumerate(inequalities): if not isinstance(inequality, Expression): raise TypeError('inequalities must be linear expressions') - inequalities[i] = inequality._toint() + inequalities[i] = inequality.scaleint() symbols = cls._xsymbols(equalities + inequalities) islbset = cls._toislbasicset(equalities, inequalities, symbols) return cls._fromislbasicset(islbset, symbols) @@ -91,12 +91,12 @@ class Polyhedron(Domain): equalities = [] inequalities = [] for islconstraint in islconstraints: - islpr = libisl.isl_printer_to_str(mainctx) constant = libisl.isl_constraint_get_constant_val(islconstraint) constant = islhelper.isl_val_to_int(constant) coefficients = {} - for dim, symbol in enumerate(symbols): - coefficient = libisl.isl_constraint_get_coefficient_val(islconstraint, libisl.isl_dim_set, dim) + for index, symbol in enumerate(symbols): + coefficient = libisl.isl_constraint_get_coefficient_val(islconstraint, + libisl.isl_dim_set, index) coefficient = islhelper.isl_val_to_int(coefficient) if coefficient != 0: coefficients[symbol] = coefficient @@ -117,34 +117,35 @@ class Polyhedron(Domain): @classmethod def _toislbasicset(cls, equalities, inequalities, symbols): dimension = len(symbols) + indices = {symbol: index for index, symbol in enumerate(symbols)} islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension) islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp)) islls = libisl.isl_local_space_from_space(islsp) for equality in equalities: isleq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(islls)) for symbol, coefficient in equality.coefficients(): - val = str(coefficient).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - sid = symbols.index(symbol) + islval = str(coefficient).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + index = indices[symbol] isleq = libisl.isl_constraint_set_coefficient_val(isleq, - libisl.isl_dim_set, sid, val) + libisl.isl_dim_set, index, islval) if equality.constant != 0: - val = str(equality.constant).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - isleq = libisl.isl_constraint_set_constant_val(isleq, val) + islval = str(equality.constant).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + isleq = libisl.isl_constraint_set_constant_val(isleq, islval) islbset = libisl.isl_basic_set_add_constraint(islbset, isleq) for inequality in inequalities: islin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(islls)) for symbol, coefficient in inequality.coefficients(): - val = str(coefficient).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - sid = symbols.index(symbol) + islval = str(coefficient).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + index = indices[symbol] islin = libisl.isl_constraint_set_coefficient_val(islin, - libisl.isl_dim_set, sid, val) + libisl.isl_dim_set, index, islval) if inequality.constant != 0: - val = str(inequality.constant).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - islin = libisl.isl_constraint_set_constant_val(islin, val) + islval = str(inequality.constant).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + islin = libisl.isl_constraint_set_constant_val(islin, islval) islbset = libisl.isl_basic_set_add_constraint(islbset, islin) return islbset @@ -171,38 +172,12 @@ class Polyhedron(Domain): else: return 'And({})'.format(', '.join(strings)) - @classmethod - def _fromsympy(cls, expr): - import sympy - equalities = [] - inequalities = [] - if expr.func == sympy.And: - for arg in expr.args: - arg_eqs, arg_ins = cls._fromsympy(arg) - equalities.extend(arg_eqs) - inequalities.extend(arg_ins) - elif expr.func == sympy.Eq: - expr = Expression.fromsympy(expr.args[0] - expr.args[1]) - equalities.append(expr) - else: - if expr.func == sympy.Lt: - expr = Expression.fromsympy(expr.args[1] - expr.args[0] - 1) - elif expr.func == sympy.Le: - expr = Expression.fromsympy(expr.args[1] - expr.args[0]) - elif expr.func == sympy.Ge: - expr = Expression.fromsympy(expr.args[0] - expr.args[1]) - elif expr.func == sympy.Gt: - expr = Expression.fromsympy(expr.args[0] - expr.args[1] - 1) - else: - raise ValueError('non-polyhedral expression: {!r}'.format(expr)) - inequalities.append(expr) - return equalities, inequalities - @classmethod def fromsympy(cls, expr): - import sympy - equalities, inequalities = cls._fromsympy(expr) - return cls(equalities, inequalities) + domain = Domain.fromsympy(expr) + if not isinstance(domain, Polyhedron): + raise ValueError('non-polyhedral expression: {!r}'.format(expr)) + return domain def tosympy(self): import sympy @@ -218,12 +193,12 @@ def _polymorphic(func): @functools.wraps(func) def wrapper(left, right): if isinstance(left, numbers.Rational): - left = Constant(left) + left = Rational(left) elif not isinstance(left, Expression): raise TypeError('left must be a a rational number ' 'or a linear expression') if isinstance(right, numbers.Rational): - right = Constant(right) + right = Rational(right) elif not isinstance(right, Expression): raise TypeError('right must be a a rational number ' 'or a linear expression')