Fix Symbol == LinExpr comparisons
[linpy.git] / linpy / linexprs.py
index 834c3b4..e4ed1cc 100644 (file)
@@ -122,7 +122,7 @@ class LinExpr:
         """
         if not isinstance(symbol, Symbol):
             raise TypeError('symbol must be a Symbol instance')
         """
         if not isinstance(symbol, Symbol):
             raise TypeError('symbol must be a Symbol instance')
-        return Rational(self._coefficients.get(symbol, 0))
+        return self._coefficients.get(symbol, Fraction(0))
 
     __getitem__ = coefficient
 
 
     __getitem__ = coefficient
 
@@ -131,15 +131,14 @@ class LinExpr:
         Iterate over the pairs (symbol, value) of linear terms in the
         expression. The constant term is ignored.
         """
         Iterate over the pairs (symbol, value) of linear terms in the
         expression. The constant term is ignored.
         """
-        for symbol, coefficient in self._coefficients.items():
-            yield symbol, Rational(coefficient)
+        yield from self._coefficients.items()
 
     @property
     def constant(self):
         """
         The constant term of the expression.
         """
 
     @property
     def constant(self):
         """
         The constant term of the expression.
         """
-        return Rational(self._constant)
+        return self._constant
 
     @property
     def symbols(self):
 
     @property
     def symbols(self):
@@ -179,9 +178,8 @@ class LinExpr:
         Iterate over the coefficient values in the expression, and the constant
         term.
         """
         Iterate over the coefficient values in the expression, and the constant
         term.
         """
-        for coefficient in self._coefficients.values():
-            yield Rational(coefficient)
-        yield Rational(self._constant)
+        yield from self._coefficients.values()
+        yield self._constant
 
     def __bool__(self):
         return True
 
     def __bool__(self):
         return True
@@ -249,9 +247,10 @@ class LinExpr:
         """
         Test whether two linear expressions are equal.
         """
         """
         Test whether two linear expressions are equal.
         """
-        return isinstance(other, LinExpr) and \
-            self._coefficients == other._coefficients and \
-            self._constant == other._constant
+        if isinstance(other, LinExpr):
+            return self._coefficients == other._coefficients and \
+                self._constant == other._constant
+        return NotImplemented
 
     def __le__(self, other):
         from .polyhedra import Le
 
     def __le__(self, other):
         from .polyhedra import Le
@@ -411,7 +410,7 @@ class LinExpr:
     @classmethod
     def fromsympy(cls, expr):
         """
     @classmethod
     def fromsympy(cls, expr):
         """
-        Create a linear expression from a sympy expression. Raise ValueError is
+        Create a linear expression from a sympy expression. Raise TypeError is
         the sympy expression is not linear.
         """
         import sympy
         the sympy expression is not linear.
         """
         import sympy
@@ -421,12 +420,18 @@ class LinExpr:
             coefficient = Fraction(coefficient.p, coefficient.q)
             if symbol == sympy.S.One:
                 constant = coefficient
             coefficient = Fraction(coefficient.p, coefficient.q)
             if symbol == sympy.S.One:
                 constant = coefficient
+            elif isinstance(symbol, sympy.Dummy):
+                # we cannot properly convert dummy symbols
+                raise TypeError('cannot convert dummy symbols')
             elif isinstance(symbol, sympy.Symbol):
                 symbol = Symbol(symbol.name)
                 coefficients.append((symbol, coefficient))
             else:
             elif isinstance(symbol, sympy.Symbol):
                 symbol = Symbol(symbol.name)
                 coefficients.append((symbol, coefficient))
             else:
-                raise ValueError('non-linear expression: {!r}'.format(expr))
-        return LinExpr(coefficients, constant)
+                raise TypeError('non-linear expression: {!r}'.format(expr))
+        expr = LinExpr(coefficients, constant)
+        if not isinstance(expr, cls):
+            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
+        return expr
 
     def tosympy(self):
         """
 
     def tosympy(self):
         """
@@ -493,7 +498,9 @@ class Symbol(LinExpr):
         return True
 
     def __eq__(self, other):
         return True
 
     def __eq__(self, other):
-        return self.sortkey() == other.sortkey()
+        if isinstance(other, Symbol):
+            return self.sortkey() == other.sortkey()
+        return NotImplemented
 
     def asdummy(self):
         """
 
     def asdummy(self):
         """
@@ -507,15 +514,20 @@ class Symbol(LinExpr):
     def _repr_latex_(self):
         return '$${}$$'.format(self.name)
 
     def _repr_latex_(self):
         return '$${}$$'.format(self.name)
 
-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        if isinstance(expr, sympy.Dummy):
-            return Dummy(expr.name)
-        elif isinstance(expr, sympy.Symbol):
-            return Symbol(expr.name)
-        else:
-            raise TypeError('expr must be a sympy.Symbol instance')
+
+def symbols(names):
+    """
+    This function returns a tuple of symbols whose names are taken from a comma
+    or whitespace delimited string, or a sequence of strings. It is useful to
+    define several symbols at once.
+
+    >>> x, y = symbols('x y')
+    >>> x, y = symbols('x, y')
+    >>> x, y = symbols(['x', 'y'])
+    """
+    if isinstance(names, str):
+        names = names.replace(',', ' ').split()
+    return tuple(Symbol(name) for name in names)
 
 
 class Dummy(Symbol):
 
 
 class Dummy(Symbol):
@@ -571,21 +583,6 @@ class Dummy(Symbol):
         return '$${}_{{{}}}$$'.format(self.name, self._index)
 
 
         return '$${}_{{{}}}$$'.format(self.name, self._index)
 
 
-def symbols(names):
-    """
-    This function returns a tuple of symbols whose names are taken from a comma
-    or whitespace delimited string, or a sequence of strings. It is useful to
-    define several symbols at once.
-
-    >>> x, y = symbols('x y')
-    >>> x, y = symbols('x, y')
-    >>> x, y = symbols(['x', 'y'])
-    """
-    if isinstance(names, str):
-        names = names.replace(',', ' ').split()
-    return tuple(Symbol(name) for name in names)
-
-
 class Rational(LinExpr, Fraction):
     """
     A particular case of linear expressions are rational values, i.e. linear
 class Rational(LinExpr, Fraction):
     """
     A particular case of linear expressions are rational values, i.e. linear
@@ -632,13 +629,3 @@ class Rational(LinExpr, Fraction):
         else:
             return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
                 self.denominator)
         else:
             return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
                 self.denominator)
-
-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        if isinstance(expr, sympy.Rational):
-            return Rational(expr.p, expr.q)
-        elif isinstance(expr, numbers.Rational):
-            return Rational(expr)
-        else:
-            raise TypeError('expr must be a sympy.Rational instance')