Simplify TypeError messages in Expression.__new__
[linpy.git] / pypol / linexprs.py
index 73c6b0e..b23eea8 100644 (file)
@@ -3,13 +3,13 @@ import functools
 import numbers
 import re
 
 import numbers
 import re
 
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict, defaultdict, Mapping
 from fractions import Fraction, gcd
 
 
 __all__ = [
     'Expression',
 from fractions import Fraction, gcd
 
 
 __all__ = [
     'Expression',
-    'Symbol', 'symbols',
+    'Symbol', 'Dummy', 'symbols',
     'Rational',
 ]
 
     'Rational',
 ]
 
@@ -45,7 +45,7 @@ class Expression:
             return Expression.fromstring(coefficients)
         if coefficients is None:
             return Rational(constant)
             return Expression.fromstring(coefficients)
         if coefficients is None:
             return Rational(constant)
-        if isinstance(coefficients, dict):
+        if isinstance(coefficients, Mapping):
             coefficients = coefficients.items()
         for symbol, coefficient in coefficients:
             if not isinstance(symbol, Symbol):
             coefficients = coefficients.items()
         for symbol, coefficient in coefficients:
             if not isinstance(symbol, Symbol):
@@ -65,14 +65,12 @@ class Expression:
             if isinstance(coefficient, Rational):
                 coefficient = coefficient.constant
             if not isinstance(coefficient, numbers.Rational):
             if isinstance(coefficient, Rational):
                 coefficient = coefficient.constant
             if not isinstance(coefficient, numbers.Rational):
-                raise TypeError('coefficients must be rational numbers '
-                    'or Rational instances')
+                raise TypeError('coefficients must be Rational instances')
             self._coefficients[symbol] = coefficient
         if isinstance(constant, Rational):
             constant = constant.constant
         if not isinstance(constant, numbers.Rational):
             self._coefficients[symbol] = coefficient
         if isinstance(constant, Rational):
             constant = constant.constant
         if not isinstance(constant, numbers.Rational):
-            raise TypeError('constant must be a rational number '
-                'or a Rational instance')
+            raise TypeError('constant must be a Rational instance')
         self._constant = constant
         self._symbols = tuple(self._coefficients)
         self._dimension = len(self._symbols)
         self._constant = constant
         self._symbols = tuple(self._coefficients)
         self._dimension = len(self._symbols)
@@ -218,7 +216,7 @@ class Expression:
 
     def subs(self, symbol, expression=None):
         if expression is None:
 
     def subs(self, symbol, expression=None):
         if expression is None:
-            if isinstance(symbol, dict):
+            if isinstance(symbol, Mapping):
                 symbol = symbol.items()
             substitutions = symbol
         else:
                 symbol = symbol.items()
             substitutions = symbol
         else:
@@ -269,39 +267,27 @@ class Expression:
 
     def __repr__(self):
         string = ''
 
     def __repr__(self):
         string = ''
-        i = 0
-        for symbol in self.symbols:
-            coefficient = self.coefficient(symbol)
+        for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
             if coefficient == 1:
-                if i == 0:
-                    string += symbol.name
-                else:
-                    string += ' + {}'.format(symbol)
+                string += '' if i == 0 else ' + '
+                string += '{!r}'.format(symbol)
             elif coefficient == -1:
             elif coefficient == -1:
-                if i == 0:
-                    string += '-{}'.format(symbol)
-                else:
-                    string += ' - {}'.format(symbol)
+                string += '-' if i == 0 else ' - '
+                string += '{!r}'.format(symbol)
             else:
                 if i == 0:
             else:
                 if i == 0:
-                    string += '{}*{}'.format(coefficient, symbol)
+                    string += '{}*{!r}'.format(coefficient, symbol)
                 elif coefficient > 0:
                 elif coefficient > 0:
-                    string += ' + {}*{}'.format(coefficient, symbol)
+                    string += ' + {}*{!r}'.format(coefficient, symbol)
                 else:
                 else:
-                    assert coefficient < 0
-                    coefficient *= -1
-                    string += ' - {}*{}'.format(coefficient, symbol)
-            i += 1
+                    string += ' - {}*{!r}'.format(-coefficient, symbol)
         constant = self.constant
         constant = self.constant
-        if constant != 0 and i == 0:
+        if len(string) == 0:
             string += '{}'.format(constant)
         elif constant > 0:
             string += ' + {}'.format(constant)
         elif constant < 0:
             string += '{}'.format(constant)
         elif constant > 0:
             string += ' + {}'.format(constant)
         elif constant < 0:
-            constant *= -1
-            string += ' - {}'.format(constant)
-        if string == '':
-            string = '0'
+            string += ' - {}'.format(-constant)
         return string
 
     def _parenstr(self, always=False):
         return string
 
     def _parenstr(self, always=False):
@@ -355,7 +341,7 @@ class Symbol(Expression):
         return self._name
 
     def __hash__(self):
         return self._name
 
     def __hash__(self):
-        return hash(self._name)
+        return hash(self.sortkey())
 
     def coefficient(self, symbol):
         if not isinstance(symbol, Symbol):
 
     def coefficient(self, symbol):
         if not isinstance(symbol, Symbol):
@@ -390,7 +376,11 @@ class Symbol(Expression):
         yield 1
 
     def __eq__(self, other):
         yield 1
 
     def __eq__(self, other):
-        return isinstance(other, Symbol) and self.name == other.name
+        return not isinstance(other, Dummy) and isinstance(other, Symbol) \
+            and self.name == other.name
+
+    def asdummy(self):
+        return Dummy(self.name)
 
     @classmethod
     def _fromast(cls, node):
 
     @classmethod
     def _fromast(cls, node):
@@ -402,15 +392,49 @@ class Symbol(Expression):
             return Symbol(node.id)
         raise SyntaxError('invalid syntax')
 
             return Symbol(node.id)
         raise SyntaxError('invalid syntax')
 
+    def __repr__(self):
+        return self.name
+
     @classmethod
     def fromsympy(cls, expr):
         import sympy
         if isinstance(expr, sympy.Symbol):
     @classmethod
     def fromsympy(cls, expr):
         import sympy
         if isinstance(expr, sympy.Symbol):
-            return Symbol(expr.name)
+            return cls(expr.name)
         else:
             raise TypeError('expr must be a sympy.Symbol instance')
 
 
         else:
             raise TypeError('expr must be a sympy.Symbol instance')
 
 
+class Dummy(Symbol):
+
+    __slots__ = (
+        '_name',
+        '_index',
+    )
+
+    _count = 0
+
+    def __new__(cls, name=None):
+        if name is None:
+            name = 'Dummy_{}'.format(Dummy._count)
+        self = object().__new__(cls)
+        self._name = name.strip()
+        self._index = Dummy._count
+        Dummy._count += 1
+        return self
+
+    def __hash__(self):
+        return hash(self.sortkey())
+
+    def sortkey(self):
+        return self._name, self._index
+
+    def __eq__(self, other):
+        return isinstance(other, Dummy) and self._index == other._index
+
+    def __repr__(self):
+        return '_{}'.format(self.name)
+
+
 def symbols(names):
     if isinstance(names, str):
         names = names.replace(',', ' ').split()
 def symbols(names):
     if isinstance(names, str):
         names = names.replace(',', ' ').split()