+++ /dev/null
-import ast
-import functools
-import numbers
-import re
-
-from collections import OrderedDict, defaultdict, Mapping
-from fractions import Fraction, gcd
-
-
-__all__ = [
- 'Expression',
- 'Symbol', 'Dummy', 'symbols',
- 'Rational',
-]
-
-
-def _polymorphic(func):
- @functools.wraps(func)
- def wrapper(left, right):
- if isinstance(right, Expression):
- return func(left, right)
- elif isinstance(right, numbers.Rational):
- right = Rational(right)
- return func(left, right)
- return NotImplemented
- return wrapper
-
-
-class Expression:
- """
- This class implements linear expressions.
- """
-
- def __new__(cls, coefficients=None, constant=0):
- if isinstance(coefficients, str):
- if constant != 0:
- raise TypeError('too many arguments')
- return Expression.fromstring(coefficients)
- if coefficients is None:
- return Rational(constant)
- if isinstance(coefficients, Mapping):
- coefficients = coefficients.items()
- coefficients = list(coefficients)
- for symbol, coefficient in coefficients:
- if not isinstance(symbol, Symbol):
- raise TypeError('symbols must be Symbol instances')
- if not isinstance(coefficient, numbers.Rational):
- raise TypeError('coefficients must be rational numbers')
- if not isinstance(constant, numbers.Rational):
- raise TypeError('constant must be a rational number')
- if len(coefficients) == 0:
- return Rational(constant)
- if len(coefficients) == 1 and constant == 0:
- symbol, coefficient = coefficients[0]
- if coefficient == 1:
- return symbol
- coefficients = [(symbol, Fraction(coefficient))
- for symbol, coefficient in coefficients if coefficient != 0]
- coefficients.sort(key=lambda item: item[0].sortkey())
- self = object().__new__(cls)
- self._coefficients = OrderedDict(coefficients)
- self._constant = Fraction(constant)
- self._symbols = tuple(self._coefficients)
- self._dimension = len(self._symbols)
- return self
-
- def coefficient(self, symbol):
- if not isinstance(symbol, Symbol):
- raise TypeError('symbol must be a Symbol instance')
- return Rational(self._coefficients.get(symbol, 0))
-
- __getitem__ = coefficient
-
- def coefficients(self):
- for symbol, coefficient in self._coefficients.items():
- yield symbol, Rational(coefficient)
-
- @property
- def constant(self):
- return Rational(self._constant)
-
- @property
- def symbols(self):
- return self._symbols
-
- @property
- def dimension(self):
- return self._dimension
-
- def __hash__(self):
- return hash((tuple(self._coefficients.items()), self._constant))
-
- def isconstant(self):
- return False
-
- def issymbol(self):
- return False
-
- def values(self):
- for coefficient in self._coefficients.values():
- yield Rational(coefficient)
- yield Rational(self._constant)
-
- def __bool__(self):
- return True
-
- def __pos__(self):
- return self
-
- def __neg__(self):
- return self * -1
-
- @_polymorphic
- def __add__(self, other):
- coefficients = defaultdict(Fraction, self._coefficients)
- for symbol, coefficient in other._coefficients.items():
- coefficients[symbol] += coefficient
- constant = self._constant + other._constant
- return Expression(coefficients, constant)
-
- __radd__ = __add__
-
- @_polymorphic
- def __sub__(self, other):
- coefficients = defaultdict(Fraction, self._coefficients)
- for symbol, coefficient in other._coefficients.items():
- coefficients[symbol] -= coefficient
- constant = self._constant - other._constant
- return Expression(coefficients, constant)
-
- @_polymorphic
- def __rsub__(self, other):
- return other - self
-
- def __mul__(self, other):
- if isinstance(other, numbers.Rational):
- coefficients = ((symbol, coefficient * other)
- for symbol, coefficient in self._coefficients.items())
- constant = self._constant * other
- return Expression(coefficients, constant)
- return NotImplemented
-
- __rmul__ = __mul__
-
- def __truediv__(self, other):
- if isinstance(other, numbers.Rational):
- coefficients = ((symbol, coefficient / other)
- for symbol, coefficient in self._coefficients.items())
- constant = self._constant / other
- return Expression(coefficients, constant)
- return NotImplemented
-
- @_polymorphic
- def __eq__(self, other):
- # returns a boolean, not a constraint
- # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
- return isinstance(other, Expression) and \
- self._coefficients == other._coefficients and \
- self._constant == other._constant
-
- def __le__(self, other):
- from .polyhedra import Le
- return Le(self, other)
-
- def __lt__(self, other):
- from .polyhedra import Lt
- return Lt(self, other)
-
- def __ge__(self, other):
- from .polyhedra import Ge
- return Ge(self, other)
-
- def __gt__(self, other):
- from .polyhedra import Gt
- return Gt(self, other)
-
- def scaleint(self):
- lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
- [value.denominator for value in self.values()])
- return self * lcm
-
- def subs(self, symbol, expression=None):
- if expression is None:
- if isinstance(symbol, Mapping):
- symbol = symbol.items()
- substitutions = symbol
- else:
- substitutions = [(symbol, expression)]
- result = self
- for symbol, expression in substitutions:
- if not isinstance(symbol, Symbol):
- raise TypeError('symbols must be Symbol instances')
- coefficients = [(othersymbol, coefficient)
- for othersymbol, coefficient in result._coefficients.items()
- if othersymbol != symbol]
- coefficient = result._coefficients.get(symbol, 0)
- constant = result._constant
- result = Expression(coefficients, constant) + coefficient*expression
- return result
-
- @classmethod
- def _fromast(cls, node):
- if isinstance(node, ast.Module) and len(node.body) == 1:
- return cls._fromast(node.body[0])
- elif isinstance(node, ast.Expr):
- return cls._fromast(node.value)
- elif isinstance(node, ast.Name):
- return Symbol(node.id)
- elif isinstance(node, ast.Num):
- return Rational(node.n)
- elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
- return -cls._fromast(node.operand)
- elif isinstance(node, ast.BinOp):
- left = cls._fromast(node.left)
- right = cls._fromast(node.right)
- if isinstance(node.op, ast.Add):
- return left + right
- elif isinstance(node.op, ast.Sub):
- return left - right
- elif isinstance(node.op, ast.Mult):
- return left * right
- elif isinstance(node.op, ast.Div):
- return left / right
- raise SyntaxError('invalid syntax')
-
- _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')
-
- @classmethod
- def fromstring(cls, string):
- # add implicit multiplication operators, e.g. '5x' -> '5*x'
- string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
- tree = ast.parse(string, 'eval')
- return cls._fromast(tree)
-
- def __repr__(self):
- string = ''
- for i, (symbol, coefficient) in enumerate(self.coefficients()):
- if coefficient == 1:
- if i != 0:
- string += ' + '
- elif coefficient == -1:
- string += '-' if i == 0 else ' - '
- elif i == 0:
- string += '{}*'.format(coefficient)
- elif coefficient > 0:
- string += ' + {}*'.format(coefficient)
- else:
- string += ' - {}*'.format(-coefficient)
- string += '{}'.format(symbol)
- constant = self.constant
- if len(string) == 0:
- string += '{}'.format(constant)
- elif constant > 0:
- string += ' + {}'.format(constant)
- elif constant < 0:
- string += ' - {}'.format(-constant)
- return string
-
- def _repr_latex_(self):
- string = ''
- for i, (symbol, coefficient) in enumerate(self.coefficients()):
- if coefficient == 1:
- if i != 0:
- string += ' + '
- elif coefficient == -1:
- string += '-' if i == 0 else ' - '
- elif i == 0:
- string += '{}'.format(coefficient._repr_latex_().strip('$'))
- elif coefficient > 0:
- string += ' + {}'.format(coefficient._repr_latex_().strip('$'))
- elif coefficient < 0:
- string += ' - {}'.format((-coefficient)._repr_latex_().strip('$'))
- string += '{}'.format(symbol._repr_latex_().strip('$'))
- constant = self.constant
- if len(string) == 0:
- string += '{}'.format(constant._repr_latex_().strip('$'))
- elif constant > 0:
- string += ' + {}'.format(constant._repr_latex_().strip('$'))
- elif constant < 0:
- string += ' - {}'.format((-constant)._repr_latex_().strip('$'))
- return '$${}$$'.format(string)
-
- def _parenstr(self, always=False):
- string = str(self)
- if not always and (self.isconstant() or self.issymbol()):
- return string
- else:
- return '({})'.format(string)
-
- @classmethod
- def fromsympy(cls, expr):
- import sympy
- coefficients = []
- constant = 0
- for symbol, coefficient in expr.as_coefficients_dict().items():
- coefficient = Fraction(coefficient.p, coefficient.q)
- if symbol == sympy.S.One:
- constant = coefficient
- elif isinstance(symbol, sympy.Symbol):
- symbol = Symbol(symbol.name)
- coefficients.append((symbol, coefficient))
- else:
- raise ValueError('non-linear expression: {!r}'.format(expr))
- return Expression(coefficients, constant)
-
- def tosympy(self):
- import sympy
- expr = 0
- for symbol, coefficient in self.coefficients():
- term = coefficient * sympy.Symbol(symbol.name)
- expr += term
- expr += self.constant
- return expr
-
-
-class Symbol(Expression):
-
- def __new__(cls, name):
- if not isinstance(name, str):
- raise TypeError('name must be a string')
- self = object().__new__(cls)
- self._name = name.strip()
- self._coefficients = {self: Fraction(1)}
- self._constant = Fraction(0)
- self._symbols = (self,)
- self._dimension = 1
- return self
-
- @property
- def name(self):
- return self._name
-
- def __hash__(self):
- return hash(self.sortkey())
-
- def sortkey(self):
- return self.name,
-
- def issymbol(self):
- return True
-
- def __eq__(self, other):
- return self.sortkey() == other.sortkey()
-
- def asdummy(self):
- return Dummy(self.name)
-
- @classmethod
- def _fromast(cls, node):
- if isinstance(node, ast.Module) and len(node.body) == 1:
- return cls._fromast(node.body[0])
- elif isinstance(node, ast.Expr):
- return cls._fromast(node.value)
- elif isinstance(node, ast.Name):
- return Symbol(node.id)
- raise SyntaxError('invalid syntax')
-
- def __repr__(self):
- return self.name
-
- def _repr_latex_(self):
- return '$${}$$'.format(self.name)
-
- @classmethod
- def fromsympy(cls, expr):
- import sympy
- if isinstance(expr, sympy.Dummy):
- return Dummy(expr.name)
- elif isinstance(expr, sympy.Symbol):
- return Symbol(expr.name)
- else:
- raise TypeError('expr must be a sympy.Symbol instance')
-
-
-class Dummy(Symbol):
-
- _count = 0
-
- def __new__(cls, name=None):
- if name is None:
- name = 'Dummy_{}'.format(Dummy._count)
- elif not isinstance(name, str):
- raise TypeError('name must be a string')
- self = object().__new__(cls)
- self._index = Dummy._count
- self._name = name.strip()
- self._coefficients = {self: Fraction(1)}
- self._constant = Fraction(0)
- self._symbols = (self,)
- self._dimension = 1
- Dummy._count += 1
- return self
-
- def __hash__(self):
- return hash(self.sortkey())
-
- def sortkey(self):
- return self._name, self._index
-
- def __repr__(self):
- return '_{}'.format(self.name)
-
- def _repr_latex_(self):
- return '$${}_{{{}}}$$'.format(self.name, self._index)
-
-
-def symbols(names):
- if isinstance(names, str):
- names = names.replace(',', ' ').split()
- return tuple(Symbol(name) for name in names)
-
-
-class Rational(Expression, Fraction):
-
- def __new__(cls, numerator=0, denominator=None):
- self = Fraction.__new__(cls, numerator, denominator)
- self._coefficients = {}
- self._constant = Fraction(self)
- self._symbols = ()
- self._dimension = 0
- return self
-
- def __hash__(self):
- return Fraction.__hash__(self)
-
- @property
- def constant(self):
- return self
-
- def isconstant(self):
- return True
-
- def __bool__(self):
- return Fraction.__bool__(self)
-
- def __repr__(self):
- if self.denominator == 1:
- return '{!r}'.format(self.numerator)
- else:
- return '{!r}/{!r}'.format(self.numerator, self.denominator)
-
- def _repr_latex_(self):
- if self.denominator == 1:
- return '$${}$$'.format(self.numerator)
- elif self.numerator < 0:
- return '$$-\\frac{{{}}}{{{}}}$$'.format(-self.numerator,
- self.denominator)
- else:
- return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator,
- self.denominator)
-
- @classmethod
- def fromsympy(cls, expr):
- import sympy
- if isinstance(expr, sympy.Rational):
- return Rational(expr.p, expr.q)
- elif isinstance(expr, numbers.Rational):
- return Rational(expr)
- else:
- raise TypeError('expr must be a sympy.Rational instance')