X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/7fff25bf40e4db570565586fb49165d1675002c2..ba88e64fedc541e9e3766e258a6cfa051acf56d9:/linpy/linexprs.py diff --git a/linpy/linexprs.py b/linpy/linexprs.py index bc36fda..b97f048 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -25,7 +25,7 @@ from fractions import Fraction, gcd __all__ = [ - 'Expression', + 'LinExpr', 'Symbol', 'Dummy', 'symbols', 'Rational', ] @@ -34,7 +34,7 @@ __all__ = [ def _polymorphic(func): @functools.wraps(func) def wrapper(left, right): - if isinstance(right, Expression): + if isinstance(right, LinExpr): return func(left, right) elif isinstance(right, numbers.Rational): right = Rational(right) @@ -43,19 +43,50 @@ def _polymorphic(func): return wrapper -class Expression: +class LinExpr: """ - This class implements linear expressions. + A linear expression consists of a list of coefficient-variable pairs + that capture the linear terms, plus a constant term. Linear expressions + are used to build constraints. They are temporary objects that typically + have short lifespans. + + Linear expressions are generally built using overloaded operators. For + example, if x is a Symbol, then x + 1 is an instance of LinExpr. + + LinExpr instances are hashable, and should be treated as immutable. """ def __new__(cls, coefficients=None, constant=0): """ - Create a new expression. + Return a linear expression from a dictionary or a sequence, that maps + symbols to their coefficients, and a constant term. The coefficients and + the constant term must be rational numbers. + + For example, the linear expression x + 2y + 1 can be constructed using + one of the following instructions: + + >>> x, y = symbols('x y') + >>> LinExpr({x: 1, y: 2}, 1) + >>> LinExpr([(x, 1), (y, 2)], 1) + + However, it may be easier to use overloaded operators: + + >>> x, y = symbols('x y') + >>> x + 2*y + 1 + + Alternatively, linear expressions can be constructed from a string: + + >>> LinExpr('x + 2*y + 1') + + A linear expression with a single symbol of coefficient 1 and no + constant term is automatically subclassed as a Symbol instance. A linear + expression with no symbol, only a constant term, is automatically + subclassed as a Rational instance. """ if isinstance(coefficients, str): if constant != 0: raise TypeError('too many arguments') - return Expression.fromstring(coefficients) + return LinExpr.fromstring(coefficients) if coefficients is None: return Rational(constant) if isinstance(coefficients, Mapping): @@ -86,39 +117,42 @@ class Expression: def coefficient(self, symbol): """ - Return the coefficient value of the given symbol. + Return the coefficient value of the given symbol, or 0 if the symbol + does not appear in the expression. """ if not isinstance(symbol, Symbol): raise TypeError('symbol must be a Symbol instance') - return Rational(self._coefficients.get(symbol, 0)) + return self._coefficients.get(symbol, Fraction(0)) __getitem__ = coefficient def coefficients(self): """ - Return a list of the coefficients of an expression + Iterate over the pairs (symbol, value) of linear terms in the + expression. The constant term is ignored. """ - for symbol, coefficient in self._coefficients.items(): - yield symbol, Rational(coefficient) + yield from self._coefficients.items() @property def constant(self): """ - Return the constant value of an expression. + The constant term of the expression. """ - return Rational(self._constant) + return self._constant @property def symbols(self): """ - Return a list of symbols in an expression. + The tuple of symbols present in the expression, sorted according to + Symbol.sortkey(). """ return self._symbols @property def dimension(self): """ - Create and return a new linear expression from a string or a list of coefficients and a constant. + The dimension of the expression, i.e. the number of symbols present in + it. """ return self._dimension @@ -127,23 +161,25 @@ class Expression: def isconstant(self): """ - Return true if an expression is a constant. + Return True if the expression only consists of a constant term. In this + case, it is a Rational instance. """ return False def issymbol(self): """ - Return true if an expression is a symbol. + Return True if an expression only consists of a symbol with coefficient + 1. In this case, it is a Symbol instance. """ return False def values(self): """ - Return the coefficient and constant values of an expression. + Iterate over the coefficient values in the expression, and the constant + term. """ - for coefficient in self._coefficients.values(): - yield Rational(coefficient) - yield Rational(self._constant) + yield from self._coefficients.values() + yield self._constant def __bool__(self): return True @@ -157,26 +193,26 @@ class Expression: @_polymorphic def __add__(self, other): """ - Return the sum of two expressions. + Return the sum of two linear expressions. """ coefficients = defaultdict(Fraction, self._coefficients) for symbol, coefficient in other._coefficients.items(): coefficients[symbol] += coefficient constant = self._constant + other._constant - return Expression(coefficients, constant) + return LinExpr(coefficients, constant) __radd__ = __add__ @_polymorphic def __sub__(self, other): """ - Return the difference between two expressions. + Return the difference between two linear expressions. """ coefficients = defaultdict(Fraction, self._coefficients) for symbol, coefficient in other._coefficients.items(): coefficients[symbol] -= coefficient constant = self._constant - other._constant - return Expression(coefficients, constant) + return LinExpr(coefficients, constant) @_polymorphic def __rsub__(self, other): @@ -184,33 +220,37 @@ class Expression: def __mul__(self, other): """ - Return the product of two expressions if other is a rational number. + Return the product of the linear expression by a rational. """ if isinstance(other, numbers.Rational): coefficients = ((symbol, coefficient * other) for symbol, coefficient in self._coefficients.items()) constant = self._constant * other - return Expression(coefficients, constant) + return LinExpr(coefficients, constant) return NotImplemented __rmul__ = __mul__ def __truediv__(self, other): + """ + Return the quotient of the linear expression by a rational. + """ if isinstance(other, numbers.Rational): coefficients = ((symbol, coefficient / other) for symbol, coefficient in self._coefficients.items()) constant = self._constant / other - return Expression(coefficients, constant) + return LinExpr(coefficients, constant) return NotImplemented @_polymorphic def __eq__(self, other): """ - Test whether two expressions are equal + Test whether two linear expressions are equal. """ - return isinstance(other, Expression) and \ - self._coefficients == other._coefficients and \ - self._constant == other._constant + if isinstance(other, LinExpr): + return self._coefficients == other._coefficients and \ + self._constant == other._constant + return NotImplemented def __le__(self, other): from .polyhedra import Le @@ -230,33 +270,40 @@ class Expression: def scaleint(self): """ - Multiply an expression by a scalar to make all coefficients integer values. + Return the expression multiplied by its lowest common denominator to + make all values integer. """ - lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), + lcd = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) - return self * lcm + return self * lcd def subs(self, symbol, expression=None): """ - Subsitute symbol by expression in equations and return the resulting - expression. + Substitute the given symbol by an expression and return the resulting + expression. Raise TypeError if the resulting expression is not linear. + + >>> x, y = symbols('x y') + >>> e = x + 2*y + 1 + >>> e.subs(y, x - 1) + 3*x - 1 + + To perform multiple substitutions at once, pass a sequence or a + dictionary of (old, new) pairs to subs. + + >>> e.subs({x: y, y: x}) + 2*x + y + 1 """ if expression is None: - if isinstance(symbol, Mapping): - symbol = symbol.items() - substitutions = symbol + substitutions = dict(symbol) else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: + substitutions = {symbol: expression} + for symbol in substitutions: if not isinstance(symbol, Symbol): raise TypeError('symbols must be Symbol instances') - coefficients = [(othersymbol, coefficient) - for othersymbol, coefficient in result._coefficients.items() - if othersymbol != symbol] - coefficient = result._coefficients.get(symbol, 0) - constant = result._constant - result = Expression(coefficients, constant) + coefficient*expression + result = self._constant + for symbol, coefficient in self._coefficients.items(): + expression = substitutions.get(symbol, symbol) + result += coefficient * expression return result @classmethod @@ -284,17 +331,21 @@ class Expression: return left / right raise SyntaxError('invalid syntax') - _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') + _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d]\w*|\()') @classmethod def fromstring(cls, string): """ - Create an expression from a string. + Create an expression from a string. Raise SyntaxError if the string is + not properly formatted. """ # add implicit multiplication operators, e.g. '5x' -> '5*x' - string = Expression._RE_NUM_VAR.sub(r'\1*\2', string) + string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string) tree = ast.parse(string, 'eval') - return cls._fromast(tree) + expr = cls._fromast(tree) + if not isinstance(expr, cls): + raise SyntaxError('invalid syntax') + return expr def __repr__(self): string = '' @@ -354,7 +405,8 @@ class Expression: @classmethod def fromsympy(cls, expr): """ - Convert sympy object to an expression. + Create a linear expression from a sympy expression. Raise TypeError is + the sympy expression is not linear. """ import sympy coefficients = [] @@ -363,16 +415,22 @@ class Expression: coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient + elif isinstance(symbol, sympy.Dummy): + # we cannot properly convert dummy symbols + raise TypeError('cannot convert dummy symbols') elif isinstance(symbol, sympy.Symbol): symbol = Symbol(symbol.name) coefficients.append((symbol, coefficient)) else: - raise ValueError('non-linear expression: {!r}'.format(expr)) - return Expression(coefficients, constant) + raise TypeError('non-linear expression: {!r}'.format(expr)) + expr = LinExpr(coefficients, constant) + if not isinstance(expr, cls): + raise TypeError('cannot convert to a {} instance'.format(cls.__name__)) + return expr def tosympy(self): """ - Return an expression as a sympy object. + Convert the linear expression to a sympy expression. """ import sympy expr = 0 @@ -383,16 +441,28 @@ class Expression: return expr -class Symbol(Expression): +class Symbol(LinExpr): + """ + Symbols are the basic components to build expressions and constraints. + They correspond to mathematical variables. Symbols are instances of + class LinExpr and inherit its functionalities. + + Two instances of Symbol are equal if they have the same name. + """ def __new__(cls, name): """ - Create and return a symbol from a string. + Return a symbol with the name string given in argument. """ if not isinstance(name, str): raise TypeError('name must be a string') + node = ast.parse(name) + try: + name = node.body[0].value.id + except (AttributeError, SyntaxError): + raise SyntaxError('invalid syntax') self = object().__new__(cls) - self._name = name.strip() + self._name = name self._coefficients = {self: Fraction(1)} self._constant = Fraction(0) self._symbols = (self,) @@ -401,62 +471,85 @@ class Symbol(Expression): @property def name(self): + """ + The name of the symbol. + """ return self._name def __hash__(self): return hash(self.sortkey()) def sortkey(self): + """ + Return a sorting key for the symbol. It is useful to sort a list of + symbols in a consistent order, as comparison functions are overridden + (see the documentation of class LinExpr). + + >>> sort(symbols, key=Symbol.sortkey) + """ return self.name, def issymbol(self): return True def __eq__(self, other): - return self.sortkey() == other.sortkey() + if isinstance(other, Symbol): + return self.sortkey() == other.sortkey() + return NotImplemented def asdummy(self): """ - Return a symbol as a Dummy Symbol. + Return a new Dummy symbol instance with the same name. """ return Dummy(self.name) - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.Name): - return Symbol(node.id) - raise SyntaxError('invalid syntax') - def __repr__(self): return self.name def _repr_latex_(self): return '$${}$$'.format(self.name) - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Dummy): - return Dummy(expr.name) - elif isinstance(expr, sympy.Symbol): - return Symbol(expr.name) - else: - raise TypeError('expr must be a sympy.Symbol instance') + +def symbols(names): + """ + This function returns a tuple of symbols whose names are taken from a comma + or whitespace delimited string, or a sequence of strings. It is useful to + define several symbols at once. + + >>> x, y = symbols('x y') + >>> x, y = symbols('x, y') + >>> x, y = symbols(['x', 'y']) + """ + if isinstance(names, str): + names = names.replace(',', ' ').split() + return tuple(Symbol(name) for name in names) class Dummy(Symbol): """ - This class returns a dummy symbol to ensure that no variables are repeated in an expression + A variation of Symbol in which all symbols are unique and identified by + an internal count index. If a name is not supplied then a string value + of the count index will be used. This is useful when a unique, temporary + variable is needed and the name of the variable used in the expression + is not important. + + Unlike Symbol, Dummy instances with the same name are not equal: + + >>> x = Symbol('x') + >>> x1, x2 = Dummy('x'), Dummy('x') + >>> x == x1 + False + >>> x1 == x2 + False + >>> x1 == x1 + True """ + _count = 0 def __new__(cls, name=None): """ - Create and return a new dummy symbol. + Return a fresh dummy symbol with the name string given in argument. """ if name is None: name = 'Dummy_{}'.format(Dummy._count) @@ -485,18 +578,12 @@ class Dummy(Symbol): return '$${}_{{{}}}$$'.format(self.name, self._index) -def symbols(names): +class Rational(LinExpr, Fraction): """ - Transform strings into instances of the Symbol class - """ - if isinstance(names, str): - names = names.replace(',', ' ').split() - return tuple(Symbol(name) for name in names) - - -class Rational(Expression, Fraction): - """ - This class represents integers and rational numbers of any size. + A particular case of linear expressions are rational values, i.e. linear + expressions consisting only of a constant term, with no symbol. They are + implemented by the Rational class, that inherits from both LinExpr and + fractions.Fraction classes. """ def __new__(cls, numerator=0, denominator=None): @@ -514,15 +601,9 @@ class Rational(Expression, Fraction): @property def constant(self): - """ - Return rational as a constant. - """ return self def isconstant(self): - """ - Test whether a value is a constant. - """ return True def __bool__(self): @@ -543,16 +624,3 @@ class Rational(Expression, Fraction): else: return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator, self.denominator) - - @classmethod - def fromsympy(cls, expr): - """ - Create a rational object from a sympy expression - """ - import sympy - if isinstance(expr, sympy.Rational): - return Rational(expr.p, expr.q) - elif isinstance(expr, numbers.Rational): - return Rational(expr) - else: - raise TypeError('expr must be a sympy.Rational instance')