X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/960f0c252361dfd696359f803aae40a9b13b14a6..7b93cea1daf2889e9ee10ca9c22a1b5124404937:/pypol/linexprs.py diff --git a/pypol/linexprs.py b/pypol/linexprs.py deleted file mode 100644 index bd3ad5a..0000000 --- a/pypol/linexprs.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright 2014 MINES ParisTech -# -# This file is part of Linpy. -# -# Linpy is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Linpy is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Linpy. If not, see . - -import ast -import functools -import numbers -import re - -from collections import OrderedDict, defaultdict, Mapping -from fractions import Fraction, gcd - - -__all__ = [ - 'Expression', - 'Symbol', 'Dummy', 'symbols', - 'Rational', -] - - -def _polymorphic(func): - @functools.wraps(func) - def wrapper(left, right): - if isinstance(right, Expression): - return func(left, right) - elif isinstance(right, numbers.Rational): - right = Rational(right) - return func(left, right) - return NotImplemented - return wrapper - - -class Expression: - """ - This class implements linear expressions. - """ - - def __new__(cls, coefficients=None, constant=0): - if isinstance(coefficients, str): - if constant != 0: - raise TypeError('too many arguments') - return Expression.fromstring(coefficients) - if coefficients is None: - return Rational(constant) - if isinstance(coefficients, Mapping): - coefficients = coefficients.items() - coefficients = list(coefficients) - for symbol, coefficient in coefficients: - if not isinstance(symbol, Symbol): - raise TypeError('symbols must be Symbol instances') - if not isinstance(coefficient, numbers.Rational): - raise TypeError('coefficients must be rational numbers') - if not isinstance(constant, numbers.Rational): - raise TypeError('constant must be a rational number') - if len(coefficients) == 0: - return Rational(constant) - if len(coefficients) == 1 and constant == 0: - symbol, coefficient = coefficients[0] - if coefficient == 1: - return symbol - coefficients = [(symbol, Fraction(coefficient)) - for symbol, coefficient in coefficients if coefficient != 0] - coefficients.sort(key=lambda item: item[0].sortkey()) - self = object().__new__(cls) - self._coefficients = OrderedDict(coefficients) - self._constant = Fraction(constant) - self._symbols = tuple(self._coefficients) - self._dimension = len(self._symbols) - return self - - def coefficient(self, symbol): - if not isinstance(symbol, Symbol): - raise TypeError('symbol must be a Symbol instance') - return Rational(self._coefficients.get(symbol, 0)) - - __getitem__ = coefficient - - def coefficients(self): - for symbol, coefficient in self._coefficients.items(): - yield symbol, Rational(coefficient) - - @property - def constant(self): - return Rational(self._constant) - - @property - def symbols(self): - return self._symbols - - @property - def dimension(self): - return self._dimension - - def __hash__(self): - return hash((tuple(self._coefficients.items()), self._constant)) - - def isconstant(self): - return False - - def issymbol(self): - return False - - def values(self): - for coefficient in self._coefficients.values(): - yield Rational(coefficient) - yield Rational(self._constant) - - def __bool__(self): - return True - - def __pos__(self): - return self - - def __neg__(self): - return self * -1 - - @_polymorphic - def __add__(self, other): - coefficients = defaultdict(Fraction, self._coefficients) - for symbol, coefficient in other._coefficients.items(): - coefficients[symbol] += coefficient - constant = self._constant + other._constant - return Expression(coefficients, constant) - - __radd__ = __add__ - - @_polymorphic - def __sub__(self, other): - coefficients = defaultdict(Fraction, self._coefficients) - for symbol, coefficient in other._coefficients.items(): - coefficients[symbol] -= coefficient - constant = self._constant - other._constant - return Expression(coefficients, constant) - - @_polymorphic - def __rsub__(self, other): - return other - self - - def __mul__(self, other): - if isinstance(other, numbers.Rational): - coefficients = ((symbol, coefficient * other) - for symbol, coefficient in self._coefficients.items()) - constant = self._constant * other - return Expression(coefficients, constant) - return NotImplemented - - __rmul__ = __mul__ - - def __truediv__(self, other): - if isinstance(other, numbers.Rational): - coefficients = ((symbol, coefficient / other) - for symbol, coefficient in self._coefficients.items()) - constant = self._constant / other - return Expression(coefficients, constant) - return NotImplemented - - @_polymorphic - def __eq__(self, other): - # returns a boolean, not a constraint - # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs - return isinstance(other, Expression) and \ - self._coefficients == other._coefficients and \ - self._constant == other._constant - - def __le__(self, other): - from .polyhedra import Le - return Le(self, other) - - def __lt__(self, other): - from .polyhedra import Lt - return Lt(self, other) - - def __ge__(self, other): - from .polyhedra import Ge - return Ge(self, other) - - def __gt__(self, other): - from .polyhedra import Gt - return Gt(self, other) - - def scaleint(self): - lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), - [value.denominator for value in self.values()]) - return self * lcm - - def subs(self, symbol, expression=None): - if expression is None: - if isinstance(symbol, Mapping): - symbol = symbol.items() - substitutions = symbol - else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: - if not isinstance(symbol, Symbol): - raise TypeError('symbols must be Symbol instances') - coefficients = [(othersymbol, coefficient) - for othersymbol, coefficient in result._coefficients.items() - if othersymbol != symbol] - coefficient = result._coefficients.get(symbol, 0) - constant = result._constant - result = Expression(coefficients, constant) + coefficient*expression - return result - - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.Name): - return Symbol(node.id) - elif isinstance(node, ast.Num): - return Rational(node.n) - elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - return -cls._fromast(node.operand) - elif isinstance(node, ast.BinOp): - left = cls._fromast(node.left) - right = cls._fromast(node.right) - if isinstance(node.op, ast.Add): - return left + right - elif isinstance(node.op, ast.Sub): - return left - right - elif isinstance(node.op, ast.Mult): - return left * right - elif isinstance(node.op, ast.Div): - return left / right - raise SyntaxError('invalid syntax') - - _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') - - @classmethod - def fromstring(cls, string): - # add implicit multiplication operators, e.g. '5x' -> '5*x' - string = Expression._RE_NUM_VAR.sub(r'\1*\2', string) - tree = ast.parse(string, 'eval') - return cls._fromast(tree) - - def __repr__(self): - string = '' - for i, (symbol, coefficient) in enumerate(self.coefficients()): - if coefficient == 1: - if i != 0: - string += ' + ' - elif coefficient == -1: - string += '-' if i == 0 else ' - ' - elif i == 0: - string += '{}*'.format(coefficient) - elif coefficient > 0: - string += ' + {}*'.format(coefficient) - else: - string += ' - {}*'.format(-coefficient) - string += '{}'.format(symbol) - constant = self.constant - if len(string) == 0: - string += '{}'.format(constant) - elif constant > 0: - string += ' + {}'.format(constant) - elif constant < 0: - string += ' - {}'.format(-constant) - return string - - def _repr_latex_(self): - string = '' - for i, (symbol, coefficient) in enumerate(self.coefficients()): - if coefficient == 1: - if i != 0: - string += ' + ' - elif coefficient == -1: - string += '-' if i == 0 else ' - ' - elif i == 0: - string += '{}'.format(coefficient._repr_latex_().strip('$')) - elif coefficient > 0: - string += ' + {}'.format(coefficient._repr_latex_().strip('$')) - elif coefficient < 0: - string += ' - {}'.format((-coefficient)._repr_latex_().strip('$')) - string += '{}'.format(symbol._repr_latex_().strip('$')) - constant = self.constant - if len(string) == 0: - string += '{}'.format(constant._repr_latex_().strip('$')) - elif constant > 0: - string += ' + {}'.format(constant._repr_latex_().strip('$')) - elif constant < 0: - string += ' - {}'.format((-constant)._repr_latex_().strip('$')) - return '$${}$$'.format(string) - - def _parenstr(self, always=False): - string = str(self) - if not always and (self.isconstant() or self.issymbol()): - return string - else: - return '({})'.format(string) - - @classmethod - def fromsympy(cls, expr): - import sympy - coefficients = [] - constant = 0 - for symbol, coefficient in expr.as_coefficients_dict().items(): - coefficient = Fraction(coefficient.p, coefficient.q) - if symbol == sympy.S.One: - constant = coefficient - elif isinstance(symbol, sympy.Symbol): - symbol = Symbol(symbol.name) - coefficients.append((symbol, coefficient)) - else: - raise ValueError('non-linear expression: {!r}'.format(expr)) - return Expression(coefficients, constant) - - def tosympy(self): - import sympy - expr = 0 - for symbol, coefficient in self.coefficients(): - term = coefficient * sympy.Symbol(symbol.name) - expr += term - expr += self.constant - return expr - - -class Symbol(Expression): - - def __new__(cls, name): - if not isinstance(name, str): - raise TypeError('name must be a string') - self = object().__new__(cls) - self._name = name.strip() - self._coefficients = {self: Fraction(1)} - self._constant = Fraction(0) - self._symbols = (self,) - self._dimension = 1 - return self - - @property - def name(self): - return self._name - - def __hash__(self): - return hash(self.sortkey()) - - def sortkey(self): - return self.name, - - def issymbol(self): - return True - - def __eq__(self, other): - return self.sortkey() == other.sortkey() - - def asdummy(self): - return Dummy(self.name) - - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.Name): - return Symbol(node.id) - raise SyntaxError('invalid syntax') - - def __repr__(self): - return self.name - - def _repr_latex_(self): - return '$${}$$'.format(self.name) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Dummy): - return Dummy(expr.name) - elif isinstance(expr, sympy.Symbol): - return Symbol(expr.name) - else: - raise TypeError('expr must be a sympy.Symbol instance') - - -class Dummy(Symbol): - - _count = 0 - - def __new__(cls, name=None): - if name is None: - name = 'Dummy_{}'.format(Dummy._count) - elif not isinstance(name, str): - raise TypeError('name must be a string') - self = object().__new__(cls) - self._index = Dummy._count - self._name = name.strip() - self._coefficients = {self: Fraction(1)} - self._constant = Fraction(0) - self._symbols = (self,) - self._dimension = 1 - Dummy._count += 1 - return self - - def __hash__(self): - return hash(self.sortkey()) - - def sortkey(self): - return self._name, self._index - - def __repr__(self): - return '_{}'.format(self.name) - - def _repr_latex_(self): - return '$${}_{{{}}}$$'.format(self.name, self._index) - - -def symbols(names): - if isinstance(names, str): - names = names.replace(',', ' ').split() - return tuple(Symbol(name) for name in names) - - -class Rational(Expression, Fraction): - - def __new__(cls, numerator=0, denominator=None): - self = object().__new__(cls) - self._coefficients = {} - self._constant = Fraction(numerator, denominator) - self._symbols = () - self._dimension = 0 - self._numerator = self._constant.numerator - self._denominator = self._constant.denominator - return self - - def __hash__(self): - return Fraction.__hash__(self) - - @property - def constant(self): - return self - - def isconstant(self): - return True - - def __bool__(self): - return Fraction.__bool__(self) - - def __repr__(self): - if self.denominator == 1: - return '{!r}'.format(self.numerator) - else: - return '{!r}/{!r}'.format(self.numerator, self.denominator) - - def _repr_latex_(self): - if self.denominator == 1: - return '$${}$$'.format(self.numerator) - elif self.numerator < 0: - return '$$-\\frac{{{}}}{{{}}}$$'.format(-self.numerator, - self.denominator) - else: - return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator, - self.denominator) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Rational): - return Rational(expr.p, expr.q) - elif isinstance(expr, numbers.Rational): - return Rational(expr) - else: - raise TypeError('expr must be a sympy.Rational instance')