X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/ba15f3f33f837b1291f74bc94081e99b860d3228..ba88e64fedc541e9e3766e258a6cfa051acf56d9:/linpy/linexprs.py diff --git a/linpy/linexprs.py b/linpy/linexprs.py index 82d75d0..b97f048 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -122,7 +122,7 @@ class LinExpr: """ if not isinstance(symbol, Symbol): raise TypeError('symbol must be a Symbol instance') - return Rational(self._coefficients.get(symbol, 0)) + return self._coefficients.get(symbol, Fraction(0)) __getitem__ = coefficient @@ -131,15 +131,14 @@ class LinExpr: Iterate over the pairs (symbol, value) of linear terms in the expression. The constant term is ignored. """ - for symbol, coefficient in self._coefficients.items(): - yield symbol, Rational(coefficient) + yield from self._coefficients.items() @property def constant(self): """ The constant term of the expression. """ - return Rational(self._constant) + return self._constant @property def symbols(self): @@ -179,9 +178,8 @@ class LinExpr: Iterate over the coefficient values in the expression, and the constant term. """ - for coefficient in self._coefficients.values(): - yield Rational(coefficient) - yield Rational(self._constant) + yield from self._coefficients.values() + yield self._constant def __bool__(self): return True @@ -249,9 +247,10 @@ class LinExpr: """ Test whether two linear expressions are equal. """ - return isinstance(other, LinExpr) and \ - self._coefficients == other._coefficients and \ - self._constant == other._constant + if isinstance(other, LinExpr): + return self._coefficients == other._coefficients and \ + self._constant == other._constant + return NotImplemented def __le__(self, other): from .polyhedra import Le @@ -274,9 +273,9 @@ class LinExpr: Return the expression multiplied by its lowest common denominator to make all values integer. """ - lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), + lcd = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) - return self * lcm + return self * lcd def subs(self, symbol, expression=None): """ @@ -295,21 +294,16 @@ class LinExpr: 2*x + y + 1 """ if expression is None: - if isinstance(symbol, Mapping): - symbol = symbol.items() - substitutions = symbol + substitutions = dict(symbol) else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: + substitutions = {symbol: expression} + for symbol in substitutions: if not isinstance(symbol, Symbol): raise TypeError('symbols must be Symbol instances') - coefficients = [(othersymbol, coefficient) - for othersymbol, coefficient in result._coefficients.items() - if othersymbol != symbol] - coefficient = result._coefficients.get(symbol, 0) - constant = result._constant - result = LinExpr(coefficients, constant) + coefficient*expression + result = self._constant + for symbol, coefficient in self._coefficients.items(): + expression = substitutions.get(symbol, symbol) + result += coefficient * expression return result @classmethod @@ -337,7 +331,7 @@ class LinExpr: return left / right raise SyntaxError('invalid syntax') - _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') + _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d]\w*|\()') @classmethod def fromstring(cls, string): @@ -348,7 +342,10 @@ class LinExpr: # add implicit multiplication operators, e.g. '5x' -> '5*x' string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string) tree = ast.parse(string, 'eval') - return cls._fromast(tree) + expr = cls._fromast(tree) + if not isinstance(expr, cls): + raise SyntaxError('invalid syntax') + return expr def __repr__(self): string = '' @@ -408,7 +405,7 @@ class LinExpr: @classmethod def fromsympy(cls, expr): """ - Create a linear expression from a sympy expression. Raise ValueError is + Create a linear expression from a sympy expression. Raise TypeError is the sympy expression is not linear. """ import sympy @@ -418,12 +415,18 @@ class LinExpr: coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient + elif isinstance(symbol, sympy.Dummy): + # we cannot properly convert dummy symbols + raise TypeError('cannot convert dummy symbols') elif isinstance(symbol, sympy.Symbol): symbol = Symbol(symbol.name) coefficients.append((symbol, coefficient)) else: - raise ValueError('non-linear expression: {!r}'.format(expr)) - return LinExpr(coefficients, constant) + raise TypeError('non-linear expression: {!r}'.format(expr)) + expr = LinExpr(coefficients, constant) + if not isinstance(expr, cls): + raise TypeError('cannot convert to a {} instance'.format(cls.__name__)) + return expr def tosympy(self): """ @@ -453,8 +456,13 @@ class Symbol(LinExpr): """ if not isinstance(name, str): raise TypeError('name must be a string') + node = ast.parse(name) + try: + name = node.body[0].value.id + except (AttributeError, SyntaxError): + raise SyntaxError('invalid syntax') self = object().__new__(cls) - self._name = name.strip() + self._name = name self._coefficients = {self: Fraction(1)} self._constant = Fraction(0) self._symbols = (self,) @@ -485,7 +493,9 @@ class Symbol(LinExpr): return True def __eq__(self, other): - return self.sortkey() == other.sortkey() + if isinstance(other, Symbol): + return self.sortkey() == other.sortkey() + return NotImplemented def asdummy(self): """ @@ -493,31 +503,26 @@ class Symbol(LinExpr): """ return Dummy(self.name) - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.Name): - return Symbol(node.id) - raise SyntaxError('invalid syntax') - def __repr__(self): return self.name def _repr_latex_(self): return '$${}$$'.format(self.name) - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Dummy): - return Dummy(expr.name) - elif isinstance(expr, sympy.Symbol): - return Symbol(expr.name) - else: - raise TypeError('expr must be a sympy.Symbol instance') + +def symbols(names): + """ + This function returns a tuple of symbols whose names are taken from a comma + or whitespace delimited string, or a sequence of strings. It is useful to + define several symbols at once. + + >>> x, y = symbols('x y') + >>> x, y = symbols('x, y') + >>> x, y = symbols(['x', 'y']) + """ + if isinstance(names, str): + names = names.replace(',', ' ').split() + return tuple(Symbol(name) for name in names) class Dummy(Symbol): @@ -573,21 +578,6 @@ class Dummy(Symbol): return '$${}_{{{}}}$$'.format(self.name, self._index) -def symbols(names): - """ - This function returns a tuple of symbols whose names are taken from a comma - or whitespace delimited string, or a sequence of strings. It is useful to - define several symbols at once. - - >>> x, y = symbols('x y') - >>> x, y = symbols('x, y') - >>> x, y = symbols(['x', 'y']) - """ - if isinstance(names, str): - names = names.replace(',', ' ').split() - return tuple(Symbol(name) for name in names) - - class Rational(LinExpr, Fraction): """ A particular case of linear expressions are rational values, i.e. linear @@ -634,13 +624,3 @@ class Rational(LinExpr, Fraction): else: return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator, self.denominator) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Rational): - return Rational(expr.p, expr.q) - elif isinstance(expr, numbers.Rational): - return Rational(expr) - else: - raise TypeError('expr must be a sympy.Rational instance')