X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/844f76f670a227d335e9bc539e1a61be7161452a..5516488d4ee3d277632ebbab6c93d45c3802c62e:/pypol/linear.py diff --git a/pypol/linear.py b/pypol/linear.py index 6550e92..a5f55fa 100644 --- a/pypol/linear.py +++ b/pypol/linear.py @@ -14,6 +14,31 @@ __all__ = [ ] +def _polymorphic_method(func): + @functools.wraps(func) + def wrapper(a, b): + if isinstance(b, Expression): + return func(a, b) + if isinstance(b, numbers.Rational): + b = constant(b) + return func(a, b) + return NotImplemented + return wrapper + +def _polymorphic_operator(func): + # A polymorphic operator should call a polymorphic method, hence we just + # have to test the left operand. + @functools.wraps(func) + def wrapper(a, b): + if isinstance(a, numbers.Rational): + a = constant(a) + return func(a, b) + elif isinstance(a, Expression): + return func(a, b) + raise TypeError('arguments must be linear expressions') + return wrapper + + class Expression: """ This class implements linear expressions. @@ -96,18 +121,7 @@ class Expression: def __neg__(self): return self * -1 - def _polymorphic(func): - @functools.wraps(func) - def wrapper(self, other): - if isinstance(other, Expression): - return func(self, other) - if isinstance(other, numbers.Rational): - other = Expression(constant=other) - return func(self, other) - return NotImplemented - return wrapper - - @_polymorphic + @_polymorphic_method def __add__(self, other): coefficients = dict(self.coefficients()) for symbol, coefficient in other.coefficients(): @@ -120,7 +134,7 @@ class Expression: __radd__ = __add__ - @_polymorphic + @_polymorphic_method def __sub__(self, other): coefficients = dict(self.coefficients()) for symbol, coefficient in other.coefficients(): @@ -133,7 +147,7 @@ class Expression: __rsub__ = __sub__ - @_polymorphic + @_polymorphic_method def __mul__(self, other): if other.isconstant(): coefficients = dict(self.coefficients()) @@ -148,7 +162,7 @@ class Expression: __rmul__ = __mul__ - @_polymorphic + @_polymorphic_method def __truediv__(self, other): if other.isconstant(): coefficients = dict(self.coefficients()) @@ -230,7 +244,7 @@ class Expression: def fromstring(cls, string): raise NotImplementedError - @_polymorphic + @_polymorphic_method def __eq__(self, other): # "normal" equality # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs @@ -246,29 +260,32 @@ class Expression: [value.denominator for value in self.values()]) return self * lcm - @_polymorphic + @_polymorphic_method def _eq(self, other): return Polyhedron(equalities=[(self - other)._canonify()]) - @_polymorphic + @_polymorphic_method def __le__(self, other): return Polyhedron(inequalities=[(self - other)._canonify()]) - @_polymorphic + @_polymorphic_method def __lt__(self, other): return Polyhedron(inequalities=[(self - other)._canonify() + 1]) - @_polymorphic + @_polymorphic_method def __ge__(self, other): return Polyhedron(inequalities=[(other - self)._canonify()]) - @_polymorphic + @_polymorphic_method def __gt__(self, other): return Polyhedron(inequalities=[(other - self)._canonify() + 1]) def constant(numerator=0, denominator=None): - return Expression(constant=Fraction(numerator, denominator)) + if denominator is None and isinstance(numerator, numbers.Rational): + return Expression(constant=numerator) + else: + return Expression(constant=Fraction(numerator, denominator)) def symbol(name): if not isinstance(name, str): @@ -281,35 +298,23 @@ def symbols(names): return (symbol(name) for name in names) -def _operator(func): - @functools.wraps(func) - def wrapper(a, b): - if isinstance(a, numbers.Rational): - a = constant(a) - if isinstance(b, numbers.Rational): - b = constant(b) - if isinstance(a, Expression) and isinstance(b, Expression): - return func(a, b) - raise TypeError('arguments must be linear expressions') - return wrapper - -@_operator +@_polymorphic_operator def eq(a, b): return a._eq(b) -@_operator +@_polymorphic_operator def le(a, b): return a <= b -@_operator +@_polymorphic_operator def lt(a, b): return a < b -@_operator +@_polymorphic_operator def ge(a, b): return a >= b -@_operator +@_polymorphic_operator def gt(a, b): return a > b