X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/de11b4e658edf3ea876aaea3d3d681eaec13dcc4..9431c353bd39a1fbb855580ee34931a07321a0f1:/pypol/linexprs.py?ds=sidebyside diff --git a/pypol/linexprs.py b/pypol/linexprs.py index ef5d90b..5ec5efd 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -3,7 +3,7 @@ import functools import numbers import re -from collections import OrderedDict, defaultdict +from collections import OrderedDict, defaultdict, Mapping from fractions import Fraction, gcd @@ -45,7 +45,7 @@ class Expression: return Expression.fromstring(coefficients) if coefficients is None: return Rational(constant) - if isinstance(coefficients, dict): + if isinstance(coefficients, Mapping): coefficients = coefficients.items() for symbol, coefficient in coefficients: if not isinstance(symbol, Symbol): @@ -218,7 +218,7 @@ class Expression: def subs(self, symbol, expression=None): if expression is None: - if isinstance(symbol, dict): + if isinstance(symbol, Mapping): symbol = symbol.items() substitutions = symbol else: @@ -269,39 +269,27 @@ class Expression: def __repr__(self): string = '' - i = 0 - for symbol in self.symbols: - coefficient = self.coefficient(symbol) + for i, (symbol, coefficient) in enumerate(self.coefficients()): if coefficient == 1: - if i == 0: - string += symbol.name - else: - string += ' + {}'.format(symbol) + string += '' if i == 0 else ' + ' + string += '{!r}'.format(symbol) elif coefficient == -1: - if i == 0: - string += '-{}'.format(symbol) - else: - string += ' - {}'.format(symbol) + string += '-' if i == 0 else ' - ' + string += '{!r}'.format(symbol) else: if i == 0: - string += '{}*{}'.format(coefficient, symbol) + string += '{}*{!r}'.format(coefficient, symbol) elif coefficient > 0: - string += ' + {}*{}'.format(coefficient, symbol) + string += ' + {}*{!r}'.format(coefficient, symbol) else: - assert coefficient < 0 - coefficient *= -1 - string += ' - {}*{}'.format(coefficient, symbol) - i += 1 + string += ' - {}*{!r}'.format(-coefficient, symbol) constant = self.constant - if constant != 0 and i == 0: + if len(string) == 0: string += '{}'.format(constant) elif constant > 0: string += ' + {}'.format(constant) elif constant < 0: - constant *= -1 - string += ' - {}'.format(constant) - if string == '': - string = '0' + string += ' - {}'.format(-constant) return string def _parenstr(self, always=False): @@ -406,11 +394,14 @@ class Symbol(Expression): return Symbol(node.id) raise SyntaxError('invalid syntax') + def __repr__(self): + return self.name + @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Symbol): - return Symbol(expr.name) + return cls(expr.name) else: raise TypeError('expr must be a sympy.Symbol instance') @@ -442,6 +433,9 @@ class Dummy(Symbol): def __eq__(self, other): return isinstance(other, Dummy) and self._index == other._index + def __repr__(self): + return '_{}'.format(self.name) + def symbols(names): if isinstance(names, str):