Better implementation of Expression.__repr__
author: Vivien Maisonneuve <v.maisonneuve@gmail.com>
Thu, 3 Jul 2014 14:10:35 +0000 (16:10 +0200)
committer: Vivien Maisonneuve <v.maisonneuve@gmail.com>
Thu, 3 Jul 2014 22:06:46 +0000 (00:06 +0200)
pypol/linexprs.py
pypol/tests/test_linexprs.py

index ef5d90b..a9f188b 100644 (file)
@@ -269,39 +269,27 @@ class Expression:
 
     def __repr__(self):
         string = ''
-        i = 0
-        for symbol in self.symbols:
-            coefficient = self.coefficient(symbol)
+        for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
-                if i == 0:
-                    string += symbol.name
-                else:
-                    string += ' + {}'.format(symbol)
+                string += '' if i == 0 else ' + '  # no separator before the first term
+                string += '{!r}'.format(symbol)  # repr(): a Dummy gets its '_' prefix
             elif coefficient == -1:
-                if i == 0:
-                    string += '-{}'.format(symbol)
-                else:
-                    string += ' - {}'.format(symbol)
+                string += '-' if i == 0 else ' - '  # first term: bare '-', no spaces
+                string += '{!r}'.format(symbol)
             else:
                 if i == 0:
-                    string += '{}*{}'.format(coefficient, symbol)
+                    string += '{}*{!r}'.format(coefficient, symbol)
                 elif coefficient > 0:
-                    string += ' + {}*{}'.format(coefficient, symbol)
+                    string += ' + {}*{!r}'.format(coefficient, symbol)
                 else:
-                    assert coefficient < 0
-                    coefficient *= -1
-                    string += ' - {}*{}'.format(coefficient, symbol)
-            i += 1
+                    string += ' - {}*{!r}'.format(-coefficient, symbol)
         constant = self.constant
-        if constant != 0 and i == 0:
+        if not string:  # no variable terms: the constant alone, even 0
             string += '{}'.format(constant)
         elif constant > 0:
             string += ' + {}'.format(constant)
         elif constant < 0:
-            constant *= -1
-            string += ' - {}'.format(constant)
-        if string == '':
-            string = '0'
+            string += ' - {}'.format(-constant)  # sign carried by ' - '
         return string
 
     def _parenstr(self, always=False):
@@ -406,11 +394,14 @@ class Symbol(Expression):
             return Symbol(node.id)
         raise SyntaxError('invalid syntax')
 
+    def __repr__(self):
+        return self.name  # bare name, e.g. 'x', as embedded in Expression reprs
+
     @classmethod
     def fromsympy(cls, expr):
         import sympy
         if isinstance(expr, sympy.Symbol):
-            return Symbol(expr.name)
+            return cls(expr.name)  # cls, so a subclass gets an instance of itself
         else:
             raise TypeError('expr must be a sympy.Symbol instance')
 
@@ -442,6 +433,9 @@ class Dummy(Symbol):
     def __eq__(self, other):
         return isinstance(other, Dummy) and self._index == other._index
 
+    def __repr__(self):
+        return '_{}'.format(self.name)  # e.g. a dummy named 'x' -> '_x'
+
 
 def symbols(names):
     if isinstance(names, str):
index 68cad74..5beaa17 100644 (file)
@@ -242,6 +242,13 @@ class TestDummy(unittest.TestCase):
         self.assertNotEqual(self.x, Dummy('x'))
         self.assertNotEqual(Dummy(), Dummy())
 
+    def test_repr(self):
+        self.assertEqual(repr(self.x), '_x')
+        dummy1 = Dummy()
+        dummy2 = Dummy()
+        self.assertTrue(repr(dummy1).startswith('_Dummy_'))
+        self.assertNotEqual(repr(dummy1), repr(dummy2))
+
 
 class TestSymbols(unittest.TestCase):