index fce77a6..b97f048 100644 (file)
@@ -122,7 +122,7 @@ class LinExpr:
"""
if not isinstance(symbol, Symbol):
raise TypeError('symbol must be a Symbol instance')
"""
if not isinstance(symbol, Symbol):
raise TypeError('symbol must be a Symbol instance')
-        return Rational(self._coefficients.get(symbol, 0))
+        return self._coefficients.get(symbol, Fraction(0))

__getitem__ = coefficient

__getitem__ = coefficient

@@ -131,15 +131,14 @@ class LinExpr:
Iterate over the pairs (symbol, value) of linear terms in the
expression. The constant term is ignored.
"""
Iterate over the pairs (symbol, value) of linear terms in the
expression. The constant term is ignored.
"""
-        for symbol, coefficient in self._coefficients.items():
-            yield symbol, Rational(coefficient)
+        yield from self._coefficients.items()

@property
def constant(self):
"""
The constant term of the expression.
"""

@property
def constant(self):
"""
The constant term of the expression.
"""
-        return Rational(self._constant)
+        return self._constant

@property
def symbols(self):

@property
def symbols(self):
@@ -179,9 +178,8 @@ class LinExpr:
Iterate over the coefficient values in the expression, and the constant
term.
"""
Iterate over the coefficient values in the expression, and the constant
term.
"""
-        for coefficient in self._coefficients.values():
-            yield Rational(coefficient)
-        yield Rational(self._constant)
+        yield from self._coefficients.values()
+        yield self._constant

def __bool__(self):
return True

def __bool__(self):
return True
@@ -249,9 +247,10 @@ class LinExpr:
"""
Test whether two linear expressions are equal.
"""
"""
Test whether two linear expressions are equal.
"""
-        return isinstance(other, LinExpr) and \
-            self._coefficients == other._coefficients and \
-            self._constant == other._constant
+        if isinstance(other, LinExpr):
+            return self._coefficients == other._coefficients and \
+                self._constant == other._constant
+        return NotImplemented

def __le__(self, other):
from .polyhedra import Le

def __le__(self, other):
from .polyhedra import Le
@@ -274,9 +273,9 @@ class LinExpr:
Return the expression multiplied by its lowest common denominator to
make all values integer.
"""
Return the expression multiplied by its lowest common denominator to
make all values integer.
"""
-        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
+        lcd = functools.reduce(lambda a, b: a*b // gcd(a, b),
[value.denominator for value in self.values()])
[value.denominator for value in self.values()])
-        return self * lcm
+        return self * lcd

def subs(self, symbol, expression=None):
"""

def subs(self, symbol, expression=None):
"""
@@ -295,21 +294,16 @@ class LinExpr:
2*x + y + 1
"""
if expression is None:
2*x + y + 1
"""
if expression is None:
-            if isinstance(symbol, Mapping):
-                symbol = symbol.items()
-            substitutions = symbol
+            substitutions = dict(symbol)
else:
else:
-            substitutions = [(symbol, expression)]
-        result = self
-        for symbol, expression in substitutions:
+            substitutions = {symbol: expression}
+        for symbol in substitutions:
if not isinstance(symbol, Symbol):
raise TypeError('symbols must be Symbol instances')
if not isinstance(symbol, Symbol):
raise TypeError('symbols must be Symbol instances')
-            coefficients = [(othersymbol, coefficient)
-                for othersymbol, coefficient in result._coefficients.items()
-                if othersymbol != symbol]
-            coefficient = result._coefficients.get(symbol, 0)
-            constant = result._constant
-            result = LinExpr(coefficients, constant) + coefficient*expression
+        result = self._constant
+        for symbol, coefficient in self._coefficients.items():
+            expression = substitutions.get(symbol, symbol)
+            result += coefficient * expression
return result

@classmethod
return result

@classmethod
@@ -337,7 +331,7 @@ class LinExpr:
return left / right
raise SyntaxError('invalid syntax')

return left / right
raise SyntaxError('invalid syntax')

-    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|$$)') + _RE_NUM_VAR = re.compile(r'(\d+|$$)\s*([^\W\d]\w*|\()')

@classmethod
def fromstring(cls, string):

@classmethod
def fromstring(cls, string):
@@ -411,7 +405,7 @@ class LinExpr:
@classmethod
def fromsympy(cls, expr):
"""
@classmethod
def fromsympy(cls, expr):
"""
-        Create a linear expression from a sympy expression. Raise ValueError is
+        Create a linear expression from a sympy expression. Raise TypeError is
the sympy expression is not linear.
"""
import sympy
the sympy expression is not linear.
"""
import sympy
@@ -421,12 +415,18 @@ class LinExpr:
coefficient = Fraction(coefficient.p, coefficient.q)
if symbol == sympy.S.One:
constant = coefficient
coefficient = Fraction(coefficient.p, coefficient.q)
if symbol == sympy.S.One:
constant = coefficient
+            elif isinstance(symbol, sympy.Dummy):
+                # we cannot properly convert dummy symbols
+                raise TypeError('cannot convert dummy symbols')
elif isinstance(symbol, sympy.Symbol):
symbol = Symbol(symbol.name)
coefficients.append((symbol, coefficient))
else:
elif isinstance(symbol, sympy.Symbol):
symbol = Symbol(symbol.name)
coefficients.append((symbol, coefficient))
else:
-                raise ValueError('non-linear expression: {!r}'.format(expr))
-        return LinExpr(coefficients, constant)
+                raise TypeError('non-linear expression: {!r}'.format(expr))
+        expr = LinExpr(coefficients, constant)
+        if not isinstance(expr, cls):
+            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
+        return expr

def tosympy(self):
"""

def tosympy(self):
"""
@@ -493,7 +493,9 @@ class Symbol(LinExpr):
return True

def __eq__(self, other):
return True

def __eq__(self, other):
-        return self.sortkey() == other.sortkey()
+        if isinstance(other, Symbol):
+            return self.sortkey() == other.sortkey()
+        return NotImplemented

def asdummy(self):
"""

def asdummy(self):
"""
@@ -507,16 +509,6 @@ class Symbol(LinExpr):
def _repr_latex_(self):
return '$${}$$'.format(self.name)

def _repr_latex_(self):
return '$${}$$'.format(self.name)

-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        if isinstance(expr, sympy.Dummy):
-            return Dummy(expr.name)
-        elif isinstance(expr, sympy.Symbol):
-            return Symbol(expr.name)
-        else:
-            raise TypeError('expr must be a sympy.Symbol instance')
-

def symbols(names):
"""

def symbols(names):
"""
@@ -632,13 +624,3 @@ class Rational(LinExpr, Fraction):
else:
return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
self.denominator)
else:
return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
self.denominator)
-
-    @classmethod
-    def fromsympy(cls, expr):
-        import sympy
-        if isinstance(expr, sympy.Rational):
-            return Rational(expr.p, expr.q)
-        elif isinstance(expr, numbers.Rational):
-            return Rational(expr)
-        else:
-            raise TypeError('expr must be a sympy.Rational instance')