# pypol/linexprs.py
#
# Reconstructed from a git blobdiff of the linpy repository
# (https://scm.cri.ensmp.fr/git/linpy.git, diff 2ffea1a..7b93cea), in which
# this file appears as deleted.

import ast
import functools
import numbers
import re

from collections import OrderedDict
from fractions import Fraction

try:
    from math import gcd
except ImportError:  # Python < 3.5 only provides fractions.gcd
    from fractions import gcd


__all__ = [
    'Expression',
    'Symbol', 'symbols', 'symbolname', 'symbolnames',
    'Constant',
]


def _polymorphic(func):
    """
    Decorate a binary method so that it always receives an Expression.

    If the right operand is already an Expression it is passed through; if it
    is a rational number it is wrapped in a Constant first. For any other
    operand type the wrapper returns NotImplemented without calling func.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if isinstance(right, Expression):
            return func(left, right)
        elif isinstance(right, numbers.Rational):
            right = Constant(right)
            return func(left, right)
        return NotImplemented
    return wrapper


def _numeric_literal(node):
    """
    Return the int or float value carried by a literal AST node, else None.

    Handles ast.Constant nodes as well as the 'Num' nodes produced by older
    parsers, without touching the deprecated ast.Num attribute. Booleans and
    other literal kinds (strings, complex numbers, None...) yield None.
    """
    constant_type = getattr(ast, 'Constant', None)
    if constant_type is not None and isinstance(node, constant_type):
        value = node.value
    elif type(node).__name__ == 'Num':
        value = node.n
    else:
        return None
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    return value


class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names to rational coefficients and carries a
    rational constant term. Instances are built once in __new__ and expose
    their content through read-only accessors.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
        '_hash',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from coefficients and a constant term.

        coefficients may be a string (parsed with fromstring(), in which case
        no constant may be given), a dict or an iterable of
        (symbol, coefficient) pairs, or None. Zero coefficients are dropped.
        The result is a Constant when no coefficient remains and a Symbol
        when the expression is a single symbol with coefficient 1 and no
        constant term.

        Raises TypeError if a string is combined with a constant, or if a
        coefficient or the constant is neither a rational number nor a
        Constant instance.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
                for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # object.__new__ is called directly: subclasses override __new__ with
        # different signatures, so super().__new__ cannot be used here.
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            symbol = symbolname(symbol)
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                        'or Constant instances')
            self._coefficients[symbol] = coefficient
        # symbols are kept sorted by name so that iteration order, string
        # output and hash do not depend on the order of the arguments
        self._coefficients = OrderedDict(sorted(self._coefficients.items()))
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                    'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        self._hash = hash((tuple(self._coefficients.items()), self._constant))
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a name or a Symbol), or 0 if the
        symbol does not appear in the expression.
        """
        symbol = symbolname(symbol)
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    def __getitem__(self, symbol):
        """
        Return self.coefficient(symbol).
        """
        # delegate at call time so that the subclass overrides of
        # coefficient() are honoured
        return self.coefficient(symbol)

    def coefficients(self):
        """
        Yield the (symbol name, coefficient) pairs of the expression.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbol names of the expression, sorted by name.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The number of symbols of the expression.
        """
        return self._dimension

    def __hash__(self):
        return self._hash

    def isconstant(self):
        """
        Return True if the expression is a Constant.
        """
        return False

    def issymbol(self):
        """
        Return True if the expression is a Symbol.
        """
        return False

    def values(self):
        """
        Yield the coefficient of each symbol, then the constant term.
        """
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two expressions.
        """
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference of two expressions.
        """
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        """
        Return other - self, other being the left operand.
        """
        # other has been converted to an Expression by the decorator, so the
        # regular subtraction applies; unsupported operand types yield
        # NotImplemented instead of an attempt to negate NotImplemented
        return other - self

    @_polymorphic
    def __mul__(self, other):
        """
        Return the product of an expression and a constant.

        Raises ValueError if both operands are non-constant, since the
        product would not be linear. Returns NotImplemented when only self
        is constant, so that the reflected operation of other is tried.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Return the quotient of an expression by a constant.

        Raises ValueError if the divisor is not constant.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    def __rtruediv__(self, other):
        """
        Return other / self, other being a rational number.

        Raises ValueError if self is not constant, since the quotient would
        not be linear. Returns NotImplemented for other operand types.
        """
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            else:
                raise ValueError('non-linear expression: '
                        '{} / {}'.format(Constant(other)._parenstr(),
                            self._parenstr()))
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Return True if both expressions have the same coefficients and the
        same constant term.
        """
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # coefficients() is used rather than the _coefficients slot, which is
        # left unset on Symbol and Constant instances
        return isinstance(other, Expression) and \
                dict(self.coefficients()) == dict(other.coefficients()) and \
                self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def _toint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant term.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into an expression.

        Names become symbols and int or float literals become constants;
        unary minus and the binary operators +, -, * and / are evaluated on
        the converted operands. Raises SyntaxError for any other node.
        """
        number = _numeric_literal(node)
        if number is not None:
            return Constant(number)
        if isinstance(node, ast.Expression):
            # root node of a tree parsed in 'eval' mode
            return cls._fromast(node.body)
        elif isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    def subs(self, symbol, expression=None):
        """
        Substitute symbols in the expression and return the result.

        Called with two arguments, replace symbol by expression. Called with
        one argument, symbol is a dict or an iterable of (symbol, expression)
        pairs, applied one after the other.
        """
        if expression is None:
            if isinstance(symbol, dict):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            symbol = symbolname(symbol)
            result = result._subs(symbol, expression)
        return result

    def _subs(self, symbol, expression):
        """
        Replace a single symbol, given by name, by expression.
        """
        coefficients = {name: coefficient
            for name, coefficient in self.coefficients()
            if name != symbol}
        coefficient = self.coefficient(symbol)
        result = Expression(coefficients, self.constant)
        result += coefficient * expression
        return result

    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string such as '2x + 3(y - 1)' into an expression.

        Raises SyntaxError if the string is not a valid expression.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # mode must be given by keyword: the second positional argument of
        # ast.parse() is the file name
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    def __str__(self):
        string = ''
        for i, symbol in enumerate(self.symbols):
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0 and not self.symbols:
            # nothing printed yet: no leading sign separator
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, wrapped in parentheses
        unless it is a constant or a symbol and always is false.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression into an expression.

        Raises ValueError if a term of expr is neither a constant nor a
        multiple of a single sympy symbol.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """
        Convert the expression into a sympy expression.
        """
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr


class Symbol(Expression):
    """
    This class implements symbols, i.e. expressions made of a single
    variable with coefficient 1 and no constant term.
    """

    # '_hash' is inherited from Expression and need not be declared again
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Build a symbol from a name (a string, stripped of surrounding
        whitespace) or from another Symbol.

        Raises TypeError, through symbolname(), for any other argument.
        """
        name = symbolname(name)
        self = object.__new__(cls)
        self._name = name
        self._hash = hash(self._name)
        return self

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        symbol = symbolname(symbol)
        if symbol == self.name:
            return 1
        else:
            return 0

    def coefficients(self):
        yield self.name, 1

    @property
    def constant(self):
        return 0

    @property
    def symbols(self):
        return self.name,

    @property
    def dimension(self):
        return 1

    def issymbol(self):
        return True

    def __eq__(self, other):
        return isinstance(other, Symbol) and self.name == other.name

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into a symbol.

        Raises SyntaxError unless the tree boils down to a single name.
        """
        if isinstance(node, ast.Expression):
            # root node of a tree parsed in 'eval' mode
            return cls._fromast(node.body)
        elif isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Symbol into a symbol.

        Raises TypeError for any other argument.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')


def symbols(names):
    """
    Return a generator of Symbol instances.

    names is either a string of names separated by commas and/or whitespace,
    or an iterable of names.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return (Symbol(name) for name in names)

def symbolname(symbol):
    """
    Return the name of a symbol given as a string (stripped of surrounding
    whitespace) or as a Symbol instance.

    Raises TypeError for any other argument.
    """
    if isinstance(symbol, str):
        return symbol.strip()
    elif isinstance(symbol, Symbol):
        return symbol.name
    else:
        raise TypeError('symbol must be a string or a Symbol instance')

def symbolnames(symbols):
    """
    Return the names of several symbols.

    A string is split on commas and whitespace and returned as a list; any
    other iterable yields a generator applying symbolname() to each item.
    """
    if isinstance(symbols, str):
        return symbols.replace(',', ' ').split()
    return (symbolname(symbol) for symbol in symbols)


class Constant(Expression):
    """
    This class implements constant expressions, i.e. rational numbers.
    """

    # '_constant' and '_hash' are inherited from Expression
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Build a constant from another Constant, or from the arguments
        accepted by fractions.Fraction.
        """
        self = object.__new__(cls)
        if denominator is None and isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            self._constant = Fraction(numerator, denominator)
        self._hash = hash(self._constant)
        return self

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        # the conversion is kept for its type check on the argument
        symbol = symbolname(symbol)
        return 0

    def coefficients(self):
        yield from []

    @property
    def symbols(self):
        return ()

    @property
    def dimension(self):
        return 0

    def isconstant(self):
        return True

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Constant) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string accepted by fractions.Fraction into a constant.

        Raises TypeError if the argument is not a string.
        """
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Rational or a rational number into a constant.

        Raises TypeError for any other argument.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')