Methods Expression.fromsympy(), Expression.tosympy()
author Vivien Maisonneuve <v.maisonneuve@gmail.com>
Mon, 23 Jun 2014 12:47:58 +0000 (14:47 +0200)
committer Vivien Maisonneuve <v.maisonneuve@gmail.com>
Tue, 24 Jun 2014 08:42:12 +0000 (10:42 +0200)
pypol/linear.py
tests/test_linear.py

index eeee698..b69836d 100644 (file)
@@ -329,6 +329,31 @@ class Expression:
     def __gt__(self, other):
         return Polyhedron(inequalities=[(self - other)._toint() - 1])
 
+    @classmethod
+    def fromsympy(cls, expr):
+        import sympy
+        coefficients = {}
+        constant = 0
+        for symbol, coefficient in expr.as_coefficients_dict().items():
+            coefficient = Fraction(coefficient.p, coefficient.q)
+            if symbol == sympy.S.One:
+                constant = coefficient
+            elif isinstance(symbol, sympy.Symbol):
+                symbol = symbol.name
+                coefficients[symbol] = coefficient
+            else:
+                raise ValueError('non-linear expression: {!r}'.format(expr))
+        return cls(coefficients, constant)
+
+    def tosympy(self):
+        """Convert this expression into an equivalent SymPy expression.
+
+        Every coefficient and the constant are converted explicitly to
+        ``sympy.Rational``, so the result is always a SymPy object, even
+        when the expression contains no symbols.
+        """
+        import sympy
+        def rational(value):
+            # value is presumably an int or Fraction (anything exposing
+            # numerator/denominator) -- TODO confirm for all callers.
+            return sympy.Rational(value.numerator, value.denominator)
+        terms = [rational(coefficient) * sympy.Symbol(symbol)
+            for symbol, coefficient in self.coefficients()]
+        return sympy.Add(rational(self.constant), *terms)
+
 
 class Constant(Expression):
 
@@ -361,6 +386,17 @@ class Constant(Expression):
             return '{}({!r}, {!r})'.format(self.__class__.__name__,
                 self.constant.numerator, self.constant.denominator)
 
+    @classmethod
+    def fromsympy(cls, expr):
+        """Build a Constant from a rational number.
+
+        Accepts a ``sympy.Rational`` (which includes ``sympy.Integer``) or
+        any ``numbers.Rational`` such as ``int`` or ``fractions.Fraction``.
+        Raises TypeError for anything else (e.g. symbols or floats).
+        """
+        import numbers
+        import sympy
+        if isinstance(expr, sympy.Rational):
+            # Checked first: pass numerator/denominator explicitly instead
+            # of relying on the constructor understanding SymPy numbers.
+            return cls(expr.p, expr.q)
+        elif isinstance(expr, numbers.Rational):
+            return cls(expr)
+        else:
+            raise TypeError('expr must be a sympy.Rational or '
+                'numbers.Rational instance')
+
+
 class Symbol(Expression):
 
     __slots__ = Expression.__slots__ + (
@@ -390,6 +426,15 @@ class Symbol(Expression):
     def __repr__(self):
         return '{}({!r})'.format(self.__class__.__name__, self._name)
 
+    @classmethod
+    def fromsympy(cls, expr):
+        import sympy
+        if isinstance(expr, sympy.Symbol):
+            return cls(expr.name)
+        else:
+            raise TypeError('expr must be a sympy.Symbol instance')
+
+
 def symbols(names):
     if isinstance(names, str):
         names = names.replace(',', ' ').split()
index 2375092..b7006a7 100644 (file)
@@ -151,10 +151,35 @@ class TestExpression(unittest.TestCase):
         self.assertEqual((self.x + self.y/2 + self.z/3)._toint(),
                 6*self.x + 3*self.y + 2*self.z)
 
+    def test_fromsympy(self):
+        import sympy
+        x, y = sympy.Symbol('x'), sympy.Symbol('y')
+        # a bare symbol, a pure constant, then a general affine expression
+        self.assertEqual(Expression.fromsympy(x), self.x)
+        self.assertEqual(Expression.fromsympy(sympy.Rational(22, 7)), self.pi)
+        self.assertEqual(Expression.fromsympy(x - 2*y + 3), self.expr)
+        # a product of two symbols is not linear
+        self.assertRaises(ValueError, Expression.fromsympy, x*y)
+
+    def test_tosympy(self):
+        import sympy
+        x = sympy.Symbol('x')
+        y = sympy.Symbol('y')
+        cases = [
+            (self.x, x),
+            (self.pi, sympy.Rational(22, 7)),
+            (self.expr, x - 2*y + 3),
+        ]
+        for expression, expected in cases:
+            self.assertEqual(expression.tosympy(), expected)
+
 
 class TestConstant(unittest.TestCase):
 
-    pass
+    def setUp(self):
+        self.zero = Constant(0)
+        self.one = Constant(1)
+        self.pi = Constant(Fraction(22, 7))
+
+    def test_fromsympy(self):
+        import sympy
+        self.assertEqual(Constant.fromsympy(sympy.Rational(22, 7)), self.pi)
+        with self.assertRaises(TypeError):
+            Constant.fromsympy(sympy.Symbol('x'))
 
 
 class TestSymbol(unittest.TestCase):
@@ -171,6 +196,17 @@ class TestSymbol(unittest.TestCase):
         self.assertListEqual(list(symbols('x,y')), [self.x, self.y])
         self.assertListEqual(list(symbols(['x', 'y'])), [self.x, self.y])
 
+    def test_fromsympy(self):
+        import sympy
+        x = sympy.Symbol('x')
+        self.assertEqual(Symbol.fromsympy(x), self.x)
+        # anything that is not a bare sympy.Symbol is rejected
+        for rejected in (sympy.Rational(22, 7), 2 * x, x*x):
+            with self.assertRaises(TypeError):
+                Symbol.fromsympy(rejected)
+
 
 class TestOperators(unittest.TestCase):