X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/556abe7f3b2c7e3985560f3e3cfb6f66bacc4122..1154bf4ff8c2d7e7882703917a58d3a42995d78a:/pypol/linexprs.py

Add Expression.subs() and Expression._subs() for substituting symbols with
expressions (a single symbol/expression pair, a dict, or an iterable of
pairs), and make symbolnames() return a tuple instead of a generator.

NOTE(review): in _subs(), the local `constant` is assigned but never used;
the new Expression is built from self.constant directly.

diff --git a/pypol/linexprs.py b/pypol/linexprs.py
index ed68493..10daf9d 100644
--- a/pypol/linexprs.py
+++ b/pypol/linexprs.py
@@ -249,6 +249,29 @@ class Expression:
                 return left / right
         raise SyntaxError('invalid syntax')
 
+    def subs(self, symbol, expression=None):
+        if expression is None:
+            if isinstance(symbol, dict):
+                symbol = symbol.items()
+            substitutions = symbol
+        else:
+            substitutions = [(symbol, expression)]
+        result = self
+        for symbol, expression in substitutions:
+            symbol = symbolname(symbol)
+            result = result._subs(symbol, expression)
+        return result
+
+    def _subs(self, symbol, expression):
+        coefficients = {name: coefficient
+            for name, coefficient in self.coefficients()
+            if name != symbol}
+        constant = self.constant
+        coefficient = self.coefficient(symbol)
+        result = Expression(coefficients, self.constant)
+        result += coefficient * expression
+        return result
+
     _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')
 
     @classmethod
@@ -418,7 +441,7 @@ def symbolname(symbol):
 def symbolnames(symbols):
     if isinstance(symbols, str):
         return symbols.replace(',', ' ').split()
-    return (symbolname(symbol) for symbol in symbols)
+    return tuple(symbolname(symbol) for symbol in symbols)
 
 
 class Constant(Expression):