X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/b0a30ee84ef3c6e9505a8f872a260db373fdb641..1154bf4ff8c2d7e7882703917a58d3a42995d78a:/pypol/linexprs.py diff --git a/pypol/linexprs.py b/pypol/linexprs.py index 9ab5c86..10daf9d 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -3,12 +3,13 @@ import functools import numbers import re +from collections import OrderedDict from fractions import Fraction, gcd __all__ = [ 'Expression', - 'Symbol', 'symbols', + 'Symbol', 'symbols', 'symbolname', 'symbolnames', 'Constant', ] @@ -35,6 +36,7 @@ class Expression: '_constant', '_symbols', '_dimension', + '_hash', ) def __new__(cls, coefficients=None, constant=0): @@ -57,31 +59,27 @@ class Expression: self = object().__new__(cls) self._coefficients = {} for symbol, coefficient in coefficients: - if isinstance(symbol, Symbol): - symbol = symbol.name - elif not isinstance(symbol, str): - raise TypeError('symbols must be strings or Symbol instances') + symbol = symbolname(symbol) if isinstance(coefficient, Constant): coefficient = coefficient.constant if not isinstance(coefficient, numbers.Rational): raise TypeError('coefficients must be rational numbers ' 'or Constant instances') self._coefficients[symbol] = coefficient + self._coefficients = OrderedDict(sorted(self._coefficients.items())) if isinstance(constant, Constant): constant = constant.constant if not isinstance(constant, numbers.Rational): raise TypeError('constant must be a rational number ' 'or a Constant instance') self._constant = constant - self._symbols = tuple(sorted(self._coefficients)) + self._symbols = tuple(self._coefficients) self._dimension = len(self._symbols) + self._hash = hash((tuple(self._coefficients.items()), self._constant)) return self def coefficient(self, symbol): - if isinstance(symbol, Symbol): - symbol = str(symbol) - elif not isinstance(symbol, str): - raise TypeError('symbol must be a string or a Symbol instance') + symbol = symbolname(symbol) try: return self._coefficients[symbol] except KeyError: @@ -90,8 +88,7 @@ class Expression: __getitem__ = coefficient def coefficients(self): - for symbol in self.symbols: - yield symbol, self.coefficient(symbol) + yield from self._coefficients.items() @property def constant(self): @@ -105,6 +102,9 @@ class Expression: def dimension(self): return self._dimension + def __hash__(self): + return self._hash + def isconstant(self): return False @@ -219,9 +219,6 @@ class Expression: from .polyhedra import Gt return Gt(self, other) - def __hash__(self): - return hash((tuple(sorted(self._coefficients.items())), self._constant)) - def _toint(self): lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) @@ -252,6 +249,29 @@ class Expression: return left / right raise SyntaxError('invalid syntax') + def subs(self, symbol, expression=None): + if expression is None: + if isinstance(symbol, dict): + symbol = symbol.items() + substitutions = symbol + else: + substitutions = [(symbol, expression)] + result = self + for symbol, expression in substitutions: + symbol = symbolname(symbol) + result = result._subs(symbol, expression) + return result + + def _subs(self, symbol, expression): + coefficients = {name: coefficient + for name, coefficient in self.coefficients() + if name != symbol} + constant = self.constant + coefficient = self.coefficient(symbol) + result = Expression(coefficients, self.constant) + result += coefficient * expression + return result + _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') @classmethod @@ -336,31 +356,53 @@ class Expression: class Symbol(Expression): - __slots__ = Expression.__slots__ + ( + __slots__ = ( '_name', + '_hash', ) def __new__(cls, name): - if isinstance(name, Symbol): - name = name.name - elif not isinstance(name, str): - raise TypeError('name must be a string or a Symbol instance') - name = name.strip() + name = symbolname(name) self = object().__new__(cls) - self._coefficients = {name: 1} - self._constant = 0 - self._symbols = tuple(name) self._name = name - self._dimension = 1 + self._hash = hash(self._name) return self @property def name(self): return self._name + def __hash__(self): + return self._hash + + def coefficient(self, symbol): + symbol = symbolname(symbol) + if symbol == self.name: + return 1 + else: + return 0 + + def coefficients(self): + yield self.name, 1 + + @property + def constant(self): + return 0 + + @property + def symbols(self): + return self.name, + + @property + def dimension(self): + return 1 + def issymbol(self): return True + def __eq__(self, other): + return isinstance(other, Symbol) and self.name == other.name + @classmethod def _fromast(cls, node): if isinstance(node, ast.Module) and len(node.body) == 1: @@ -388,23 +430,61 @@ def symbols(names): names = names.replace(',', ' ').split() return (Symbol(name) for name in names) +def symbolname(symbol): + if isinstance(symbol, str): + return symbol.strip() + elif isinstance(symbol, Symbol): + return symbol.name + else: + raise TypeError('symbol must be a string or a Symbol instance') + +def symbolnames(symbols): + if isinstance(symbols, str): + return symbols.replace(',', ' ').split() + return tuple(symbolname(symbol) for symbol in symbols) + class Constant(Expression): + __slots__ = ( + '_constant', + '_hash', + ) + def __new__(cls, numerator=0, denominator=None): self = object().__new__(cls) if denominator is None and isinstance(numerator, Constant): self._constant = numerator.constant else: self._constant = Fraction(numerator, denominator) - self._coefficients = {} - self._symbols = () - self._dimension = 0 + self._hash = hash(self._constant) return self + def __hash__(self): + return self._hash + + def coefficient(self, symbol): + symbol = symbolname(symbol) + return 0 + + def coefficients(self): + yield from [] + + @property + def symbols(self): + return () + + @property + def dimension(self): + return 0 + def isconstant(self): return True + @_polymorphic + def __eq__(self, other): + return isinstance(other, Constant) and self.constant == other.constant + def __bool__(self): return self.constant != 0