From: Vivien Maisonneuve Date: Wed, 2 Jul 2014 17:12:07 +0000 (+0200) Subject: Cleaner and faster linear expressions X-Git-Tag: 1.0~163 X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/commitdiff_plain/29ed88d1a15d283ea6f3340a4dd97e8cc7c2d2d4 Cleaner and faster linear expressions --- diff --git a/pypol/domains.py b/pypol/domains.py index 12cf471..9187081 100644 --- a/pypol/domains.py +++ b/pypol/domains.py @@ -50,7 +50,7 @@ class Domain: symbols = set() for item in iterator: symbols.update(item.symbols) - return tuple(sorted(symbols)) + return tuple(sorted(symbols, key=lambda symbol: symbol.name)) @property def polyhedra(self): @@ -139,7 +139,7 @@ class Domain: def simplify(self): #does not change anything in any of the examples - #isl seems to do this naturally + #isl seems to do this naturally islset = self._toislset(self.polyhedra, self.symbols) islset = libisl.isl_set_remove_redundancies(islset) return self._fromislset(islset, self.symbols) @@ -152,30 +152,21 @@ class Domain: islbset = libisl.isl_set_polyhedral_hull(islset) return Polyhedron._fromislbasicset(islbset, self.symbols) - def drop_dims(self, dims): - # use to remove certain variables use isl_set_drop_constraints_involving_dims instead? - from .polyhedra import Polyhedron - n = 0 - dims = sorted(dims) - symbols = sorted(self.symbols) + def project_out(self, symbols): + # use to remove certain variables islset = self._toislset(self.polyhedra, self.symbols) - for dim in dims: - dim_index = dims.index(dim) - if dim in symbols: - first = symbols.index(dim) - try: - for dim in dims: - if symbols[first+1] is dims[dim_index+1]: #check if next value in symbols is same as next value in dims - n += 1 - islbset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, first, n) - symbols.remove(dim) - except: - islbset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, first, 1) - symbols.remove(dim) - else: - islbset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, 0, 0) - return Polyhedron._fromislset(islbset, symbols) - + n = 0 + for index, symbol in reversed(list(enumerate(self.symbols))): + if symbol in symbols: + n += 1 + elif n > 0: + islset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, index + 1, n) + n = 0 + if n > 0: + islset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, 0, n) + symbols = [symbol for symbol in self.symbols if symbol not in symbols] + return Domain._fromislset(islset, symbols) + def sample(self): from .polyhedra import Polyhedron islset = self._toislset(self.polyhedra, self.symbols) diff --git a/pypol/linexprs.py b/pypol/linexprs.py index 10daf9d..ccd1564 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -3,13 +3,13 @@ import functools import numbers import re -from collections import OrderedDict +from collections import OrderedDict, defaultdict from fractions import Fraction, gcd __all__ = [ 'Expression', - 'Symbol', 'symbols', 'symbolname', 'symbolnames', + 'Symbol', 'symbols', 'Constant', ] @@ -36,37 +36,38 @@ class Expression: '_constant', '_symbols', '_dimension', - '_hash', ) def __new__(cls, coefficients=None, constant=0): if isinstance(coefficients, str): if constant: raise TypeError('too many arguments') - return cls.fromstring(coefficients) - if isinstance(coefficients, dict): - coefficients = coefficients.items() + return Expression.fromstring(coefficients) if coefficients is None: return Constant(constant) + if isinstance(coefficients, dict): + coefficients = coefficients.items() + for symbol, coefficient in coefficients: + if not isinstance(symbol, Symbol): + raise TypeError('symbols must be Symbol instances') coefficients = [(symbol, coefficient) for symbol, coefficient in coefficients if coefficient != 0] if len(coefficients) == 0: return Constant(constant) - elif len(coefficients) == 1 and constant == 0: + if len(coefficients) == 1 and constant == 0: symbol, coefficient = coefficients[0] if coefficient == 1: - return Symbol(symbol) + return symbol self = object().__new__(cls) - self._coefficients = {} - for symbol, coefficient in coefficients: - symbol = symbolname(symbol) + self._coefficients = OrderedDict() + for symbol, coefficient in sorted(coefficients, + key=lambda item: item[0].name): if isinstance(coefficient, Constant): coefficient = coefficient.constant if not isinstance(coefficient, numbers.Rational): raise TypeError('coefficients must be rational numbers ' 'or Constant instances') self._coefficients[symbol] = coefficient - self._coefficients = OrderedDict(sorted(self._coefficients.items())) if isinstance(constant, Constant): constant = constant.constant if not isinstance(constant, numbers.Rational): @@ -75,11 +76,11 @@ class Expression: self._constant = constant self._symbols = tuple(self._coefficients) self._dimension = len(self._symbols) - self._hash = hash((tuple(self._coefficients.items()), self._constant)) return self def coefficient(self, symbol): - symbol = symbolname(symbol) + if not isinstance(symbol, Symbol): + raise TypeError('symbol must be a Symbol instance') try: return self._coefficients[symbol] except KeyError: @@ -103,7 +104,7 @@ class Expression: return self._dimension def __hash__(self): - return self._hash + return hash((tuple(self._coefficients.items()), self._constant)) def isconstant(self): return False @@ -112,8 +113,7 @@ class Expression: return False def values(self): - for symbol in self.symbols: - yield self.coefficient(symbol) + yield from self._coefficients.values() yield self.constant def __bool__(self): @@ -127,12 +127,9 @@ class Expression: @_polymorphic def __add__(self, other): - coefficients = dict(self.coefficients()) + coefficients = defaultdict(Constant, self.coefficients()) for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] += coefficient - else: - coefficients[symbol] = coefficient + coefficients[symbol] += coefficient constant = self.constant + other.constant return Expression(coefficients, constant) @@ -140,12 +137,9 @@ class Expression: @_polymorphic def __sub__(self, other): - coefficients = dict(self.coefficients()) + coefficients = defaultdict(Constant, self.coefficients()) for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] -= coefficient - else: - coefficients[symbol] = -coefficient + coefficients[symbol] -= coefficient constant = self.constant - other.constant return Expression(coefficients, constant) @@ -172,9 +166,8 @@ class Expression: if other.isconstant(): coefficients = dict(self.coefficients()) for symbol in coefficients: - coefficients[symbol] = \ - Fraction(coefficients[symbol], other.constant) - constant = Fraction(self.constant, other.constant) + coefficients[symbol] = Constant(coefficients[symbol], other.constant) + constant = Constant(self.constant, other.constant) return Expression(coefficients, constant) if isinstance(other, Expression): raise ValueError('non-linear expression: ' @@ -184,8 +177,7 @@ class Expression: def __rtruediv__(self, other): if isinstance(other, self): if self.isconstant(): - constant = Fraction(other, self.constant) - return Expression(constant=constant) + return Constant(other, self.constant) else: raise ValueError('non-linear expression: ' '{} / {}'.format(other._parenstr(), self._parenstr())) @@ -196,8 +188,8 @@ class Expression: # "normal" equality # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs return isinstance(other, Expression) and \ - self._coefficients == other._coefficients and \ - self.constant == other.constant + self._coefficients == other._coefficients and \ + self.constant == other.constant @_polymorphic def __le__(self, other): @@ -219,11 +211,28 @@ class Expression: from .polyhedra import Gt return Gt(self, other) - def _toint(self): + def scaleint(self): lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) return self * lcm + def subs(self, symbol, expression=None): + if expression is None: + if isinstance(symbol, dict): + symbol = symbol.items() + substitutions = symbol + else: + substitutions = [(symbol, expression)] + result = self + for symbol, expression in substitutions: + coefficients = [(othersymbol, coefficient) + for othersymbol, coefficient in result.coefficients() + if othersymbol != symbol] + coefficient = result.coefficient(symbol) + constant = result.constant + result = Expression(coefficients, constant) + coefficient*expression + return result + @classmethod def _fromast(cls, node): if isinstance(node, ast.Module) and len(node.body) == 1: @@ -249,46 +258,23 @@ class Expression: return left / right raise SyntaxError('invalid syntax') - def subs(self, symbol, expression=None): - if expression is None: - if isinstance(symbol, dict): - symbol = symbol.items() - substitutions = symbol - else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: - symbol = symbolname(symbol) - result = result._subs(symbol, expression) - return result - - def _subs(self, symbol, expression): - coefficients = {name: coefficient - for name, coefficient in self.coefficients() - if name != symbol} - constant = self.constant - coefficient = self.coefficient(symbol) - result = Expression(coefficients, self.constant) - result += coefficient * expression - return result - _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') @classmethod def fromstring(cls, string): # add implicit multiplication operators, e.g. '5x' -> '5*x' - string = cls._RE_NUM_VAR.sub(r'\1*\2', string) + string = Expression._RE_NUM_VAR.sub(r'\1*\2', string) tree = ast.parse(string, 'eval') return cls._fromast(tree) - def __str__(self): + def __repr__(self): string = '' i = 0 for symbol in self.symbols: coefficient = self.coefficient(symbol) if coefficient == 1: if i == 0: - string += symbol + string += symbol.name else: string += ' + {}'.format(symbol) elif coefficient == -1: @@ -325,30 +311,27 @@ class Expression: else: return '({})'.format(string) - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, str(self)) - @classmethod def fromsympy(cls, expr): import sympy - coefficients = {} + coefficients = [] constant = 0 for symbol, coefficient in expr.as_coefficients_dict().items(): coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient elif isinstance(symbol, sympy.Symbol): - symbol = symbol.name - coefficients[symbol] = coefficient + symbol = Symbol(symbol.name) + coefficients.append((symbol, coefficient)) else: raise ValueError('non-linear expression: {!r}'.format(expr)) - return cls(coefficients, constant) + return Expression(coefficients, constant) def tosympy(self): import sympy expr = 0 for symbol, coefficient in self.coefficients(): - term = coefficient * sympy.Symbol(symbol) + term = coefficient * sympy.Symbol(symbol.name) expr += term expr += self.constant return expr @@ -358,14 +341,13 @@ class Symbol(Expression): __slots__ = ( '_name', - '_hash', ) def __new__(cls, name): - name = symbolname(name) + if not isinstance(name, str): + raise TypeError('name must be a string') self = object().__new__(cls) - self._name = name - self._hash = hash(self._name) + self._name = name.strip() return self @property @@ -373,17 +355,18 @@ class Symbol(Expression): return self._name def __hash__(self): - return self._hash + return hash(self._name) def coefficient(self, symbol): - symbol = symbolname(symbol) - if symbol == self.name: + if not isinstance(symbol, Symbol): + raise TypeError('symbol must be a Symbol instance') + if symbol == self: return 1 else: return 0 def coefficients(self): - yield self.name, 1 + yield self, 1 @property def constant(self): @@ -391,7 +374,7 @@ class Symbol(Expression): @property def symbols(self): - return self.name, + return self, @property def dimension(self): @@ -400,6 +383,9 @@ class Symbol(Expression): def issymbol(self): return True + def values(self): + yield 1 + def __eq__(self, other): return isinstance(other, Symbol) and self.name == other.name @@ -413,14 +399,11 @@ class Symbol(Expression): return Symbol(node.id) raise SyntaxError('invalid syntax') - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, self._name) - @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Symbol): - return cls(expr.name) + return Symbol(expr.name) else: raise TypeError('expr must be a sympy.Symbol instance') @@ -428,27 +411,13 @@ class Symbol(Expression): def symbols(names): if isinstance(names, str): names = names.replace(',', ' ').split() - return (Symbol(name) for name in names) - -def symbolname(symbol): - if isinstance(symbol, str): - return symbol.strip() - elif isinstance(symbol, Symbol): - return symbol.name - else: - raise TypeError('symbol must be a string or a Symbol instance') - -def symbolnames(symbols): - if isinstance(symbols, str): - return symbols.replace(',', ' ').split() - return tuple(symbolname(symbol) for symbol in symbols) + return tuple(Symbol(name) for name in names) class Constant(Expression): __slots__ = ( '_constant', - '_hash', ) def __new__(cls, numerator=0, denominator=None): @@ -457,18 +426,18 @@ class Constant(Expression): self._constant = numerator.constant else: self._constant = Fraction(numerator, denominator) - self._hash = hash(self._constant) return self def __hash__(self): - return self._hash + return hash(self.constant) def coefficient(self, symbol): - symbol = symbolname(symbol) + if not isinstance(symbol, Symbol): + raise TypeError('symbol must be a Symbol instance') return 0 def coefficients(self): - yield from [] + yield from () @property def symbols(self): @@ -481,6 +450,9 @@ class Constant(Expression): def isconstant(self): return True + def values(self): + yield self._constant + @_polymorphic def __eq__(self, other): return isinstance(other, Constant) and self.constant == other.constant @@ -490,25 +462,16 @@ class Constant(Expression): @classmethod def fromstring(cls, string): - if isinstance(string, str): - return Constant(Fraction(string)) - else: + if not isinstance(string, str): raise TypeError('string must be a string instance') - - def __repr__(self): - if self.constant.denominator == 1: - return '{}({!r})'.format(self.__class__.__name__, - self.constant.numerator) - else: - return '{}({!r}, {!r})'.format(self.__class__.__name__, - self.constant.numerator, self.constant.denominator) + return Constant(Fraction(string)) @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Rational): - return cls(expr.p, expr.q) + return Constant(expr.p, expr.q) elif isinstance(expr, numbers.Rational): - return cls(expr) + return Constant(expr) else: raise TypeError('expr must be a sympy.Rational instance') diff --git a/pypol/polyhedra.py b/pypol/polyhedra.py index ac67cf8..6ef7cc1 100644 --- a/pypol/polyhedra.py +++ b/pypol/polyhedra.py @@ -44,14 +44,14 @@ class Polyhedron(Domain): for i, equality in enumerate(equalities): if not isinstance(equality, Expression): raise TypeError('equalities must be linear expressions') - equalities[i] = equality._toint() + equalities[i] = equality.scaleint() if inequalities is None: inequalities = [] else: for i, inequality in enumerate(inequalities): if not isinstance(inequality, Expression): raise TypeError('inequalities must be linear expressions') - inequalities[i] = inequality._toint() + inequalities[i] = inequality.scaleint() symbols = cls._xsymbols(equalities + inequalities) islbset = cls._toislbasicset(equalities, inequalities, symbols) return cls._fromislbasicset(islbset, symbols) @@ -95,7 +95,8 @@ class Polyhedron(Domain): constant = islhelper.isl_val_to_int(constant) coefficients = {} for index, symbol in enumerate(symbols): - coefficient = libisl.isl_constraint_get_coefficient_val(islconstraint, libisl.isl_dim_set, index) + coefficient = libisl.isl_constraint_get_coefficient_val(islconstraint, + libisl.isl_dim_set, index) coefficient = islhelper.isl_val_to_int(coefficient) if coefficient != 0: coefficients[symbol] = coefficient diff --git a/pypol/tests/test_domains.py b/pypol/tests/test_domains.py index 55853fd..e9dd1cd 100644 --- a/pypol/tests/test_domains.py +++ b/pypol/tests/test_domains.py @@ -1,7 +1,7 @@ import unittest from ..domains import * -from ..linexprs import symbols +from ..linexprs import Symbol, symbols from ..polyhedra import * @@ -78,10 +78,10 @@ class TestDomain(unittest.TestCase): self.assertEqual(self.square1.polyhedral_hull(), self.hull) def test_project_out(self): - self.assertEqual(self.square1.project_out('x'), self.dropped) - self.assertEqual(self.square1.project_out('x y'), self.universe) - self.assertEqual(self.universe.project_out(' '), self.universe) - self.assertEqual(self.empty.project_out(' '), Empty) + self.assertEqual(self.square1.project_out(symbols('x')), self.dropped) + self.assertEqual(self.square1.project_out(symbols('x y')), self.universe) + self.assertEqual(self.universe.project_out([]), self.universe) + self.assertEqual(self.empty.project_out([]), Empty) def test_simplify(self): self.assertEqual(self.universe.simplify(), self.universe) diff --git a/pypol/tests/test_linexprs.py b/pypol/tests/test_linexprs.py index bc062b6..8dfd13e 100644 --- a/pypol/tests/test_linexprs.py +++ b/pypol/tests/test_linexprs.py @@ -10,9 +10,9 @@ from .libhelper import requires_sympy class TestExpression(unittest.TestCase): def setUp(self): - self.x = Expression({'x': 1}) - self.y = Expression({'y': 1}) - self.z = Expression({'z': 1}) + self.x = Symbol('x') + self.y = Symbol('y') + self.z = Symbol('z') self.zero = Expression(constant=0) self.one = Expression(constant=1) self.pi = Expression(constant=Fraction(22, 7)) @@ -23,11 +23,10 @@ class TestExpression(unittest.TestCase): self.assertIsInstance(self.pi, Constant) self.assertNotIsInstance(self.x + self.pi, Symbol) self.assertNotIsInstance(self.x + self.pi, Constant) - xx = Expression({'x': 2}) + xx = Expression({self.x: 2}) self.assertNotIsInstance(xx, Symbol) with self.assertRaises(TypeError): Expression('x + y', 2) - self.assertEqual(Expression({'x': 2}), Expression({self.x: 2})) with self.assertRaises(TypeError): Expression({0: 2}) with self.assertRaises(TypeError): @@ -38,27 +37,29 @@ class TestExpression(unittest.TestCase): Expression(constant='a') def test_coefficient(self): - self.assertEqual(self.expr.coefficient('x'), 1) - self.assertEqual(self.expr.coefficient('y'), -2) + self.assertEqual(self.expr.coefficient(self.x), 1) self.assertEqual(self.expr.coefficient(self.y), -2) - self.assertEqual(self.expr.coefficient('z'), 0) + self.assertEqual(self.expr.coefficient(self.z), 0) + with self.assertRaises(TypeError): + self.expr.coefficients('x') with self.assertRaises(TypeError): self.expr.coefficient(0) with self.assertRaises(TypeError): self.expr.coefficient(self.expr) def test_getitem(self): - self.assertEqual(self.expr['x'], 1) - self.assertEqual(self.expr['y'], -2) + self.assertEqual(self.expr[self.x], 1) self.assertEqual(self.expr[self.y], -2) - self.assertEqual(self.expr['z'], 0) + self.assertEqual(self.expr[self.z], 0) + with self.assertRaises(TypeError): + self.assertEqual(self.expr['x'], 1) with self.assertRaises(TypeError): self.expr[0] with self.assertRaises(TypeError): self.expr[self.expr] def test_coefficients(self): - self.assertCountEqual(self.expr.coefficients(), [('x', 1), ('y', -2)]) + self.assertCountEqual(self.expr.coefficients(), [(self.x, 1), (self.y, -2)]) def test_constant(self): self.assertEqual(self.x.constant, 0) @@ -66,9 +67,9 @@ class TestExpression(unittest.TestCase): self.assertEqual(self.expr.constant, 3) def test_symbols(self): - self.assertCountEqual(self.x.symbols, ['x']) + self.assertCountEqual(self.x.symbols, [self.x]) self.assertCountEqual(self.pi.symbols, []) - self.assertCountEqual(self.expr.symbols, ['x', 'y']) + self.assertCountEqual(self.expr.symbols, [self.x, self.y]) def test_dimension(self): self.assertEqual(self.x.dimension, 1) @@ -127,21 +128,29 @@ class TestExpression(unittest.TestCase): self.assertNotEqual(self.x, self.y) self.assertEqual(self.zero, 0) - def test__toint(self): - self.assertEqual((self.x + self.y/2 + self.z/3)._toint(), + def test_scaleint(self): + self.assertEqual((self.x + self.y/2 + self.z/3).scaleint(), 6*self.x + 3*self.y + 2*self.z) def test_subs(self): - self.assertEqual(self.x.subs('x', 3), 3) - self.assertEqual(self.x.subs('x', self.x), self.x) - self.assertEqual(self.x.subs('x', self.y), self.y) - self.assertEqual(self.x.subs('x', self.x + self.y), self.x + self.y) - self.assertEqual(self.x.subs('y', 3), self.x) - self.assertEqual(self.pi.subs('x', 3), self.pi) - self.assertEqual(self.expr.subs('x', -3), -2 * self.y) - self.assertEqual(self.expr.subs([('x', self.y), ('y', self.x)]), 3 - self.x) - self.assertEqual(self.expr.subs({'x': self.z, 'y': self.z}), 3 - self.z) + self.assertEqual(self.x.subs(self.x, 3), 3) + self.assertEqual(self.x.subs(self.x, self.x), self.x) + self.assertEqual(self.x.subs(self.x, self.y), self.y) + self.assertEqual(self.x.subs(self.x, self.x + self.y), self.x + self.y) + self.assertEqual(self.x.subs(self.y, 3), self.x) + self.assertEqual(self.pi.subs(self.x, 3), self.pi) + self.assertEqual(self.expr.subs(self.x, -3), -2 * self.y) + self.assertEqual(self.expr.subs([(self.x, self.y), (self.y, self.x)]), 3 - self.x) self.assertEqual(self.expr.subs({self.x: self.z, self.y: self.z}), 3 - self.z) + self.assertEqual(self.expr.subs({self.x: self.z, self.y: self.z}), 3 - self.z) + with self.assertRaises(TypeError): + self.x.subs('x', 3) + with self.assertRaises(TypeError): + self.expr.subs([('x', self.z), ('y', self.z)]) + with self.assertRaises(TypeError): + self.expr.subs({'x': self.z, 'y': self.z}) + with self.assertRaises(TypeError): + self.expr.subs(self.x, 'x') def test_fromstring(self): self.assertEqual(Expression.fromstring('x'), self.x) @@ -151,20 +160,13 @@ class TestExpression(unittest.TestCase): self.assertEqual(Expression.fromstring('x - (3-1)y + 3'), self.expr) self.assertEqual(Expression.fromstring('x - 2*y + 3'), self.expr) - def test_str(self): + def test_repr(self): self.assertEqual(str(Expression()), '0') self.assertEqual(str(self.x), 'x') self.assertEqual(str(-self.x), '-x') self.assertEqual(str(self.pi), '22/7') self.assertEqual(str(self.expr), 'x - 2*y + 3') - def test_repr(self): - self.assertEqual(repr(self.x), "Symbol('x')") - self.assertEqual(repr(self.one), 'Constant(1)') - self.assertEqual(repr(self.pi), 'Constant(22, 7)') - self.assertEqual(repr(self.x + self.one), "Expression('x + 1')") - self.assertEqual(repr(self.expr), "Expression('x - 2*y + 3')") - @requires_sympy def test_fromsympy(self): import sympy @@ -192,7 +194,8 @@ class TestSymbol(unittest.TestCase): def test_new(self): self.assertEqual(Symbol(' x '), self.x) - self.assertEqual(Symbol(self.x), self.x) + with self.assertRaises(TypeError): + Symbol(self.x) with self.assertRaises(TypeError): Symbol(1) @@ -207,11 +210,8 @@ class TestSymbol(unittest.TestCase): with self.assertRaises(SyntaxError): Symbol.fromstring('1') - def test_str(self): - self.assertEqual(str(self.x), 'x') - def test_repr(self): - self.assertEqual(repr(self.x), "Symbol('x')") + self.assertEqual(str(self.x), 'x') @requires_sympy def test_fromsympy(self): @@ -259,9 +259,9 @@ class TestConstant(unittest.TestCase): Constant.fromstring(1) def test_repr(self): - self.assertEqual(repr(self.zero), 'Constant(0)') - self.assertEqual(repr(self.one), 'Constant(1)') - self.assertEqual(repr(self.pi), 'Constant(22, 7)') + self.assertEqual(repr(self.zero), '0') + self.assertEqual(repr(self.one), '1') + self.assertEqual(repr(self.pi), '22/7') @requires_sympy def test_fromsympy(self): diff --git a/pypol/tests/test_polyhedra.py b/pypol/tests/test_polyhedra.py index c7a58a4..cb2015f 100644 --- a/pypol/tests/test_polyhedra.py +++ b/pypol/tests/test_polyhedra.py @@ -13,7 +13,7 @@ class TestPolyhedron(unittest.TestCase): self.square = Polyhedron(inequalities=[x, 1 - x, y, 1 - y]) def test_symbols(self): - self.assertCountEqual(self.square.symbols, ['x', 'y']) + self.assertCountEqual(self.square.symbols, symbols('x y')) def test_dimension(self): self.assertEqual(self.square.dimension, 2)