From 49fa0d574a20817712006d479549dd96ccfff652 Mon Sep 17 00:00:00 2001 From: Vivien Maisonneuve Date: Tue, 19 Aug 2014 00:19:12 +0200 Subject: [PATCH] Symplify class verification in LinExpr.fromsympy() --- doc/reference.rst | 2 +- linpy/linexprs.py | 32 +++++++++----------------------- linpy/tests/test_linexprs.py | 2 +- 3 files changed, 11 insertions(+), 25 deletions(-) diff --git a/doc/reference.rst b/doc/reference.rst index e0efdde..de5300d 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -203,7 +203,7 @@ For example, if x is a :class:Symbol, then x + 1 is an instance of :cl .. classmethod:: fromsympy(expr) Create a linear expression from a :mod:sympy expression. - Raise :exc:ValueError is the :mod:sympy expression is not linear. + Raise :exc:TypeError is the :mod:sympy expression is not linear. .. method:: tosympy() diff --git a/linpy/linexprs.py b/linpy/linexprs.py index fce77a6..38fa7a1 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -411,7 +411,7 @@ class LinExpr: @classmethod def fromsympy(cls, expr): """ - Create a linear expression from a sympy expression. Raise ValueError is + Create a linear expression from a sympy expression. Raise TypeError is the sympy expression is not linear. """ import sympy @@ -421,12 +421,18 @@ class LinExpr: coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient + elif isinstance(symbol, sympy.Dummy): + # we cannot properly convert dummy symbols + raise TypeError('cannot convert dummy symbols') elif isinstance(symbol, sympy.Symbol): symbol = Symbol(symbol.name) coefficients.append((symbol, coefficient)) else: - raise ValueError('non-linear expression: {!r}'.format(expr)) - return LinExpr(coefficients, constant) + raise TypeError('non-linear expression: {!r}'.format(expr)) + expr = LinExpr(coefficients, constant) + if not isinstance(expr, cls): + raise TypeError('cannot convert to a {} instance'.format(cls.__name__)) + return expr def tosympy(self): """ @@ -507,16 +513,6 @@ class Symbol(LinExpr): def _repr_latex_(self): return '$${}$$'.format(self.name) - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Dummy): - return Dummy(expr.name) - elif isinstance(expr, sympy.Symbol): - return Symbol(expr.name) - else: - raise TypeError('expr must be a sympy.Symbol instance') - def symbols(names): """ @@ -632,13 +628,3 @@ class Rational(LinExpr, Fraction): else: return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator, self.denominator) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Rational): - return Rational(expr.p, expr.q) - elif isinstance(expr, numbers.Rational): - return Rational(expr) - else: - raise TypeError('expr must be a sympy.Rational instance') diff --git a/linpy/tests/test_linexprs.py b/linpy/tests/test_linexprs.py index 9599d06..1b01186 100644 --- a/linpy/tests/test_linexprs.py +++ b/linpy/tests/test_linexprs.py @@ -195,7 +195,7 @@ class TestLinExpr(unittest.TestCase): self.assertEqual(LinExpr.fromsympy(sp_x), self.x) self.assertEqual(LinExpr.fromsympy(sympy.Rational(22, 7)), self.pi) self.assertEqual(LinExpr.fromsympy(sp_x - 2*sp_y + 3), self.expr) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): LinExpr.fromsympy(sp_x*sp_y) @requires_sympy -- 2.20.1