X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/49fa0d574a20817712006d479549dd96ccfff652..7fda4e3865bb7cc9b6ee15cfd3c62c207f585ec7:/linpy/linexprs.py?ds=sidebyside diff --git a/linpy/linexprs.py b/linpy/linexprs.py index 38fa7a1..b5864e1 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -122,7 +122,7 @@ class LinExpr: """ if not isinstance(symbol, Symbol): raise TypeError('symbol must be a Symbol instance') - return Rational(self._coefficients.get(symbol, 0)) + return self._coefficients.get(symbol, Fraction(0)) __getitem__ = coefficient @@ -131,15 +131,14 @@ class LinExpr: Iterate over the pairs (symbol, value) of linear terms in the expression. The constant term is ignored. """ - for symbol, coefficient in self._coefficients.items(): - yield symbol, Rational(coefficient) + yield from self._coefficients.items() @property def constant(self): """ The constant term of the expression. """ - return Rational(self._constant) + return self._constant @property def symbols(self): @@ -179,9 +178,8 @@ class LinExpr: Iterate over the coefficient values in the expression, and the constant term. """ - for coefficient in self._coefficients.values(): - yield Rational(coefficient) - yield Rational(self._constant) + yield from self._coefficients.values() + yield self._constant def __bool__(self): return True @@ -249,9 +247,10 @@ class LinExpr: """ Test whether two linear expressions are equal. """ - return isinstance(other, LinExpr) and \ - self._coefficients == other._coefficients and \ - self._constant == other._constant + if isinstance(other, LinExpr): + return self._coefficients == other._coefficients and \ + self._constant == other._constant + return NotImplemented def __le__(self, other): from .polyhedra import Le @@ -274,9 +273,9 @@ class LinExpr: Return the expression multiplied by its lowest common denominator to make all values integer. """ - lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), + lcd = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) - return self * lcm + return self * lcd def subs(self, symbol, expression=None): """ @@ -295,21 +294,16 @@ class LinExpr: 2*x + y + 1 """ if expression is None: - if isinstance(symbol, Mapping): - symbol = symbol.items() - substitutions = symbol + substitutions = dict(symbol) else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: + substitutions = {symbol: expression} + for symbol in substitutions: if not isinstance(symbol, Symbol): raise TypeError('symbols must be Symbol instances') - coefficients = [(othersymbol, coefficient) - for othersymbol, coefficient in result._coefficients.items() - if othersymbol != symbol] - coefficient = result._coefficients.get(symbol, 0) - constant = result._constant - result = LinExpr(coefficients, constant) + coefficient*expression + result = self._constant + for symbol, coefficient in self._coefficients.items(): + expression = substitutions.get(symbol, symbol) + result += coefficient * expression return result @classmethod @@ -499,7 +493,9 @@ class Symbol(LinExpr): return True def __eq__(self, other): - return self.sortkey() == other.sortkey() + if isinstance(other, Symbol): + return self.sortkey() == other.sortkey() + return NotImplemented def asdummy(self): """