X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/1d494bb187b70135df721c13306d7f26fdf33f50..2a1055d4f4754fa33c53d3f155cc050e4911a2a3:/pypol/linexprs.py?ds=inline diff --git a/pypol/linexprs.py b/pypol/linexprs.py index 0db7edd..c5f4336 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -3,13 +3,14 @@ import functools import numbers import re +from collections import OrderedDict, defaultdict, Mapping from fractions import Fraction, gcd __all__ = [ 'Expression', - 'Symbol', 'symbols', - 'Constant', + 'Symbol', 'Dummy', 'symbols', + 'Rational', ] @@ -19,7 +20,7 @@ def _polymorphic(func): if isinstance(right, Expression): return func(left, right) elif isinstance(right, numbers.Rational): - right = Constant(right) + right = Rational(right) return func(left, right) return NotImplemented return wrapper @@ -30,72 +31,53 @@ class Expression: This class implements linear expressions. """ - __slots__ = ( - '_coefficients', - '_constant', - '_symbols', - '_dimension', - ) - def __new__(cls, coefficients=None, constant=0): if isinstance(coefficients, str): - if constant: + if constant != 0: raise TypeError('too many arguments') - return cls.fromstring(coefficients) - if isinstance(coefficients, dict): - coefficients = coefficients.items() + return Expression.fromstring(coefficients) if coefficients is None: - return Constant(constant) - coefficients = [(symbol, coefficient) - for symbol, coefficient in coefficients if coefficient != 0] + return Rational(constant) + if isinstance(coefficients, Mapping): + coefficients = coefficients.items() + coefficients = list(coefficients) + for symbol, coefficient in coefficients: + if not isinstance(symbol, Symbol): + raise TypeError('symbols must be Symbol instances') + if not isinstance(coefficient, numbers.Rational): + raise TypeError('coefficients must be rational numbers') + if not isinstance(constant, numbers.Rational): + raise TypeError('constant must be a rational number') if len(coefficients) == 0: - return Constant(constant) - elif len(coefficients) == 1 and constant == 0: + return Rational(constant) + if len(coefficients) == 1 and constant == 0: symbol, coefficient = coefficients[0] if coefficient == 1: - return Symbol(symbol) + return symbol + coefficients = [(symbol, Fraction(coefficient)) + for symbol, coefficient in coefficients if coefficient != 0] + coefficients.sort(key=lambda item: item[0].sortkey()) self = object().__new__(cls) - self._coefficients = {} - for symbol, coefficient in coefficients: - if isinstance(symbol, Symbol): - symbol = symbol.name - elif not isinstance(symbol, str): - raise TypeError('symbols must be strings or Symbol instances') - if isinstance(coefficient, Constant): - coefficient = coefficient.constant - if not isinstance(coefficient, numbers.Rational): - raise TypeError('coefficients must be rational numbers ' - 'or Constant instances') - self._coefficients[symbol] = coefficient - if isinstance(constant, Constant): - constant = constant.constant - if not isinstance(constant, numbers.Rational): - raise TypeError('constant must be a rational number ' - 'or a Constant instance') - self._constant = constant - self._symbols = tuple(sorted(self._coefficients)) + self._coefficients = OrderedDict(coefficients) + self._constant = Fraction(constant) + self._symbols = tuple(self._coefficients) self._dimension = len(self._symbols) return self def coefficient(self, symbol): - if isinstance(symbol, Symbol): - symbol = str(symbol) - elif not isinstance(symbol, str): - raise TypeError('symbol must be a string or a Symbol instance') - try: - return self._coefficients[symbol] - except KeyError: - return 0 + if not isinstance(symbol, Symbol): + raise TypeError('symbol must be a Symbol instance') + return Rational(self._coefficients.get(symbol, 0)) __getitem__ = coefficient def coefficients(self): - for symbol in self.symbols: - yield symbol, self.coefficient(symbol) + for symbol, coefficient in self._coefficients.items(): + yield symbol, Rational(coefficient) @property def constant(self): - return self._constant + return Rational(self._constant) @property def symbols(self): @@ -105,6 +87,9 @@ class Expression: def dimension(self): return self._dimension + def __hash__(self): + return hash((tuple(self._coefficients.items()), self._constant)) + def isconstant(self): return False @@ -112,9 +97,9 @@ class Expression: return False def values(self): - for symbol in self.symbols: - yield self.coefficient(symbol) - yield self.constant + for coefficient in self._coefficients.values(): + yield Rational(coefficient) + yield Rational(self._constant) def __bool__(self): return True @@ -127,106 +112,92 @@ class Expression: @_polymorphic def __add__(self, other): - coefficients = dict(self.coefficients()) - for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] += coefficient - else: - coefficients[symbol] = coefficient - constant = self.constant + other.constant + coefficients = defaultdict(Fraction, self._coefficients) + for symbol, coefficient in other._coefficients.items(): + coefficients[symbol] += coefficient + constant = self._constant + other._constant return Expression(coefficients, constant) __radd__ = __add__ @_polymorphic def __sub__(self, other): - coefficients = dict(self.coefficients()) - for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] -= coefficient - else: - coefficients[symbol] = -coefficient - constant = self.constant - other.constant + coefficients = defaultdict(Fraction, self._coefficients) + for symbol, coefficient in other._coefficients.items(): + coefficients[symbol] -= coefficient + constant = self._constant - other._constant return Expression(coefficients, constant) + @_polymorphic def __rsub__(self, other): - return -(self - other) + return other - self - @_polymorphic def __mul__(self, other): - if other.isconstant(): - coefficients = dict(self.coefficients()) - for symbol in coefficients: - coefficients[symbol] *= other.constant - constant = self.constant * other.constant + if isinstance(other, numbers.Rational): + coefficients = ((symbol, coefficient * other) + for symbol, coefficient in self._coefficients.items()) + constant = self._constant * other return Expression(coefficients, constant) - if isinstance(other, Expression) and not self.isconstant(): - raise ValueError('non-linear expression: ' - '{} * {}'.format(self._parenstr(), other._parenstr())) return NotImplemented __rmul__ = __mul__ - @_polymorphic def __truediv__(self, other): - if other.isconstant(): - coefficients = dict(self.coefficients()) - for symbol in coefficients: - coefficients[symbol] = \ - Fraction(coefficients[symbol], other.constant) - constant = Fraction(self.constant, other.constant) + if isinstance(other, numbers.Rational): + coefficients = ((symbol, coefficient / other) + for symbol, coefficient in self._coefficients.items()) + constant = self._constant / other return Expression(coefficients, constant) - if isinstance(other, Expression): - raise ValueError('non-linear expression: ' - '{} / {}'.format(self._parenstr(), other._parenstr())) - return NotImplemented - - def __rtruediv__(self, other): - if isinstance(other, self): - if self.isconstant(): - constant = Fraction(other, self.constant) - return Expression(constant=constant) - else: - raise ValueError('non-linear expression: ' - '{} / {}'.format(other._parenstr(), self._parenstr())) return NotImplemented @_polymorphic def __eq__(self, other): - # "normal" equality + # returns a boolean, not a constraint # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs return isinstance(other, Expression) and \ - self._coefficients == other._coefficients and \ - self.constant == other.constant + self._coefficients == other._coefficients and \ + self._constant == other._constant - @_polymorphic def __le__(self, other): from .polyhedra import Le return Le(self, other) - @_polymorphic def __lt__(self, other): from .polyhedra import Lt return Lt(self, other) - @_polymorphic def __ge__(self, other): from .polyhedra import Ge return Ge(self, other) - @_polymorphic def __gt__(self, other): from .polyhedra import Gt return Gt(self, other) - def __hash__(self): - return hash((tuple(sorted(self._coefficients.items())), self._constant)) - - def _toint(self): + def scaleint(self): lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) return self * lcm + def subs(self, symbol, expression=None): + if expression is None: + if isinstance(symbol, Mapping): + symbol = symbol.items() + substitutions = symbol + else: + substitutions = [(symbol, expression)] + result = self + for symbol, expression in substitutions: + if not isinstance(symbol, Symbol): + raise TypeError('symbols must be Symbol instances') + coefficients = [(othersymbol, coefficient) + for othersymbol, coefficient in result._coefficients.items() + if othersymbol != symbol] + coefficient = result._coefficients.get(symbol, 0) + constant = result._constant + result = Expression(coefficients, constant) + coefficient*expression + return result + @classmethod def _fromast(cls, node): if isinstance(node, ast.Module) and len(node.body) == 1: @@ -236,7 +207,7 @@ class Expression: elif isinstance(node, ast.Name): return Symbol(node.id) elif isinstance(node, ast.Num): - return Constant(node.n) + return Rational(node.n) elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): return -cls._fromast(node.operand) elif isinstance(node, ast.BinOp): @@ -252,49 +223,63 @@ class Expression: return left / right raise SyntaxError('invalid syntax') + _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') + @classmethod def fromstring(cls, string): - string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string) + # add implicit multiplication operators, e.g. '5x' -> '5*x' + string = Expression._RE_NUM_VAR.sub(r'\1*\2', string) tree = ast.parse(string, 'eval') return cls._fromast(tree) - def __str__(self): + def __repr__(self): string = '' - i = 0 - for symbol in self.symbols: - coefficient = self.coefficient(symbol) + for i, (symbol, coefficient) in enumerate(self.coefficients()): if coefficient == 1: - if i == 0: - string += symbol - else: - string += ' + {}'.format(symbol) + if i != 0: + string += ' + ' elif coefficient == -1: - if i == 0: - string += '-{}'.format(symbol) - else: - string += ' - {}'.format(symbol) + string += '-' if i == 0 else ' - ' + elif i == 0: + string += '{}*'.format(coefficient) + elif coefficient > 0: + string += ' + {}*'.format(coefficient) else: - if i == 0: - string += '{}*{}'.format(coefficient, symbol) - elif coefficient > 0: - string += ' + {}*{}'.format(coefficient, symbol) - else: - assert coefficient < 0 - coefficient *= -1 - string += ' - {}*{}'.format(coefficient, symbol) - i += 1 + string += ' - {}*'.format(-coefficient) + string += '{}'.format(symbol) constant = self.constant - if constant != 0 and i == 0: + if len(string) == 0: string += '{}'.format(constant) elif constant > 0: string += ' + {}'.format(constant) elif constant < 0: - constant *= -1 - string += ' - {}'.format(constant) - if string == '': - string = '0' + string += ' - {}'.format(-constant) return string + def _repr_latex_(self): + string = '' + for i, (symbol, coefficient) in enumerate(self.coefficients()): + if coefficient == 1: + if i != 0: + string += ' + ' + elif coefficient == -1: + string += '-' if i == 0 else ' - ' + elif i == 0: + string += '{}'.format(coefficient._repr_latex_().strip('$')) + elif coefficient > 0: + string += ' + {}'.format(coefficient._repr_latex_().strip('$')) + elif coefficient < 0: + string += ' - {}'.format((-coefficient)._repr_latex_().strip('$')) + string += '{}'.format(symbol._repr_latex_().strip('$')) + constant = self.constant + if len(string) == 0: + string += '{}'.format(constant._repr_latex_().strip('$')) + elif constant > 0: + string += ' + {}'.format(constant._repr_latex_().strip('$')) + elif constant < 0: + string += ' - {}'.format((-constant)._repr_latex_().strip('$')) + return '${}$'.format(string) + def _parenstr(self, always=False): string = str(self) if not always and (self.isconstant() or self.issymbol()): @@ -302,30 +287,27 @@ class Expression: else: return '({})'.format(string) - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, str(self)) - @classmethod def fromsympy(cls, expr): import sympy - coefficients = {} + coefficients = [] constant = 0 for symbol, coefficient in expr.as_coefficients_dict().items(): coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient elif isinstance(symbol, sympy.Symbol): - symbol = symbol.name - coefficients[symbol] = coefficient + symbol = Symbol(symbol.name) + coefficients.append((symbol, coefficient)) else: raise ValueError('non-linear expression: {!r}'.format(expr)) - return cls(coefficients, constant) + return Expression(coefficients, constant) def tosympy(self): import sympy expr = 0 for symbol, coefficient in self.coefficients(): - term = coefficient * sympy.Symbol(symbol) + term = coefficient * sympy.Symbol(symbol.name) expr += term expr += self.constant return expr @@ -333,21 +315,14 @@ class Expression: class Symbol(Expression): - __slots__ = Expression.__slots__ + ( - '_name', - ) - def __new__(cls, name): - if isinstance(name, Symbol): - name = name.name - elif not isinstance(name, str): - raise TypeError('name must be a string or a Symbol instance') - name = name.strip() + if not isinstance(name, str): + raise TypeError('name must be a string') self = object().__new__(cls) - self._coefficients = {name: 1} - self._constant = 0 - self._symbols = tuple(name) - self._name = name + self._name = name.strip() + self._coefficients = {self: Fraction(1)} + self._constant = Fraction(0) + self._symbols = (self,) self._dimension = 1 return self @@ -355,9 +330,21 @@ class Symbol(Expression): def name(self): return self._name + def __hash__(self): + return hash(self.sortkey()) + + def sortkey(self): + return self.name, + def issymbol(self): return True + def __eq__(self, other): + return self.sortkey() == other.sortkey() + + def asdummy(self): + return Dummy(self.name) + @classmethod def _fromast(cls, node): if isinstance(node, ast.Module) and len(node.body) == 1: @@ -369,63 +356,105 @@ class Symbol(Expression): raise SyntaxError('invalid syntax') def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, self._name) + return self.name + + def _repr_latex_(self): + return '${}$'.format(self.name) @classmethod def fromsympy(cls, expr): import sympy - if isinstance(expr, sympy.Symbol): - return cls(expr.name) + if isinstance(expr, sympy.Dummy): + return Dummy(expr.name) + elif isinstance(expr, sympy.Symbol): + return Symbol(expr.name) else: raise TypeError('expr must be a sympy.Symbol instance') +class Dummy(Symbol): + + _count = 0 + + def __new__(cls, name=None): + if name is None: + name = 'Dummy_{}'.format(Dummy._count) + elif not isinstance(name, str): + raise TypeError('name must be a string') + self = object().__new__(cls) + self._index = Dummy._count + self._name = name.strip() + self._coefficients = {self: Fraction(1)} + self._constant = Fraction(0) + self._symbols = (self,) + self._dimension = 1 + Dummy._count += 1 + return self + + def __hash__(self): + return hash(self.sortkey()) + + def sortkey(self): + return self._name, self._index + + def __repr__(self): + return '_{}'.format(self.name) + + def _repr_latex_(self): + return '${}_{{{}}}$'.format(self.name, self._index) + + def symbols(names): if isinstance(names, str): names = names.replace(',', ' ').split() - return (Symbol(name) for name in names) + return tuple(Symbol(name) for name in names) -class Constant(Expression): +class Rational(Expression, Fraction): def __new__(cls, numerator=0, denominator=None): - self = object().__new__(cls) - if denominator is None and isinstance(numerator, Constant): - self._constant = numerator.constant - else: - self._constant = Fraction(numerator, denominator) + self = Fraction.__new__(cls, numerator, denominator) self._coefficients = {} + self._constant = Fraction(self) self._symbols = () self._dimension = 0 return self + def __hash__(self): + return Fraction.__hash__(self) + + @property + def constant(self): + return self + def isconstant(self): return True def __bool__(self): - return self.constant != 0 - - @classmethod - def fromstring(cls, string): - if isinstance(string, str): - return Constant(Fraction(string)) - else: - raise TypeError('string must be a string instance') + return Fraction.__bool__(self) def __repr__(self): - if self.constant.denominator == 1: - return '{}({!r})'.format(self.__class__.__name__, - self.constant.numerator) + if self.denominator == 1: + return '{!r}'.format(self.numerator) + else: + return '{!r}/{!r}'.format(self.numerator, self.denominator) + + def _repr_latex_(self): + if self.denominator == 1: + return '${}$'.format(self.numerator) + elif self.numerator < 0: + return '$-\\frac{{{}}}{{{}}}$'.format(-self.numerator, + self.denominator) else: - return '{}({!r}, {!r})'.format(self.__class__.__name__, - self.constant.numerator, self.constant.denominator) + return '$\\frac{{{}}}{{{}}}$'.format(self.numerator, + self.denominator) @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Rational): - return cls(expr.p, expr.q) + return Rational(expr.p, expr.q) elif isinstance(expr, numbers.Rational): - return cls(expr) + return Rational(expr) else: raise TypeError('expr must be a sympy.Rational instance')