Fix Symbol == LinExpr comparisons
[linpy.git] / linpy / linexprs.py
index 99ad4ec..e4ed1cc 100644 (file)
@@ -131,8 +131,7 @@ class LinExpr:
         Iterate over the pairs (symbol, value) of linear terms in the
         expression. The constant term is ignored.
         """
         Iterate over the pairs (symbol, value) of linear terms in the
         expression. The constant term is ignored.
         """
-        for symbol, coefficient in self._coefficients.items():
-            yield symbol, coefficient
+        yield from self._coefficients.items()
 
     @property
     def constant(self):
 
     @property
     def constant(self):
@@ -179,8 +178,7 @@ class LinExpr:
         Iterate over the coefficient values in the expression, and the constant
         term.
         """
         Iterate over the coefficient values in the expression, and the constant
         term.
         """
-        for coefficient in self._coefficients.values():
-            yield coefficient
+        yield from self._coefficients.values()
         yield self._constant
 
     def __bool__(self):
         yield self._constant
 
     def __bool__(self):
@@ -249,9 +247,10 @@ class LinExpr:
         """
         Test whether two linear expressions are equal.
         """
         """
         Test whether two linear expressions are equal.
         """
-        return isinstance(other, LinExpr) and \
-            self._coefficients == other._coefficients and \
-            self._constant == other._constant
+        if isinstance(other, LinExpr):
+            return self._coefficients == other._coefficients and \
+                self._constant == other._constant
+        return NotImplemented
 
     def __le__(self, other):
         from .polyhedra import Le
 
     def __le__(self, other):
         from .polyhedra import Le
@@ -499,7 +498,9 @@ class Symbol(LinExpr):
         return True
 
     def __eq__(self, other):
         return True
 
     def __eq__(self, other):
-        return self.sortkey() == other.sortkey()
+        if isinstance(other, Symbol):
+            return self.sortkey() == other.sortkey()
+        return NotImplemented
 
     def asdummy(self):
         """
 
     def asdummy(self):
         """