]
+def _polymorphic_method(func):
+    """Let a binary Expression method accept rational right operands.
+
+    The right operand is forwarded untouched when it is already an
+    ``Expression`` and promoted with ``constant()`` when it is a
+    ``numbers.Rational``.  Any other type yields ``NotImplemented`` so
+    that, when the method is reached through an operator, Python can fall
+    back to the other operand's reflected method.
+    """
+    @functools.wraps(func)
+    def wrapper(a, b):
+        if not isinstance(b, Expression):
+            if not isinstance(b, numbers.Rational):
+                # Unsupported operand type: let Python try the other side.
+                return NotImplemented
+            b = constant(b)
+        return func(a, b)
+    return wrapper
+
+def _polymorphic_operator(func):
+    """Wrap a module-level comparison function so it accepts numbers.
+
+    A rational left operand is promoted to a constant expression with
+    ``constant()``; any other non-Expression left operand is rejected.
+    The right operand is left to the polymorphic method that *func*
+    invokes on the left operand (e.g. ``Expression._eq`` or
+    ``Expression.__le__``).
+
+    Those methods return ``NotImplemented`` for unsupported right
+    operands.  When *func* calls such a method directly (as ``eq`` does
+    with ``_eq``) rather than through an operator, Python never resolves
+    that sentinel, so it is converted into a ``TypeError`` here.
+
+    Raises:
+        TypeError: if either argument is neither an Expression nor a
+            rational number.
+    """
+    @functools.wraps(func)
+    def wrapper(a, b):
+        if isinstance(a, numbers.Rational):
+            a = constant(a)
+        elif not isinstance(a, Expression):
+            raise TypeError('arguments must be linear expressions')
+        result = func(a, b)
+        # A direct method call can leak the NotImplemented sentinel;
+        # report it the same way as a bad left operand.
+        if result is NotImplemented:
+            raise TypeError('arguments must be linear expressions')
+        return result
+    return wrapper
+
+
class Expression:
"""
This class implements linear expressions.
def __neg__(self):
return self * -1
- def _polymorphic(func):
- @functools.wraps(func)
- def wrapper(self, other):
- if isinstance(other, Expression):
- return func(self, other)
- if isinstance(other, numbers.Rational):
- other = Expression(constant=other)
- return func(self, other)
- return NotImplemented
- return wrapper
-
- @_polymorphic
+ @_polymorphic_method
def __add__(self, other):
coefficients = dict(self.coefficients())
for symbol, coefficient in other.coefficients():
__radd__ = __add__
- @_polymorphic
+ @_polymorphic_method
def __sub__(self, other):
coefficients = dict(self.coefficients())
for symbol, coefficient in other.coefficients():
__rsub__ = __sub__
- @_polymorphic
+ @_polymorphic_method
def __mul__(self, other):
if other.isconstant():
coefficients = dict(self.coefficients())
__rmul__ = __mul__
- @_polymorphic
+ @_polymorphic_method
def __truediv__(self, other):
if other.isconstant():
coefficients = dict(self.coefficients())
def fromstring(cls, string):
raise NotImplementedError
- @_polymorphic
+ @_polymorphic_method
def __eq__(self, other):
# "normal" equality
# see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
[value.denominator for value in self.values()])
return self * lcm
- @_polymorphic
+ @_polymorphic_method
def _eq(self, other):
return Polyhedron(equalities=[(self - other)._canonify()])
- @_polymorphic
+ @_polymorphic_method
def __le__(self, other):
return Polyhedron(inequalities=[(self - other)._canonify()])
- @_polymorphic
+ @_polymorphic_method
def __lt__(self, other):
return Polyhedron(inequalities=[(self - other)._canonify() + 1])
- @_polymorphic
+ @_polymorphic_method
def __ge__(self, other):
return Polyhedron(inequalities=[(other - self)._canonify()])
- @_polymorphic
+ @_polymorphic_method
def __gt__(self, other):
return Polyhedron(inequalities=[(other - self)._canonify() + 1])
def constant(numerator=0, denominator=None):
- return Expression(constant=Fraction(numerator, denominator))
+    """Return a constant Expression built from *numerator* and *denominator*.
+
+    A rational *numerator* given without a *denominator* is stored as-is;
+    every other combination is converted with ``Fraction(numerator,
+    denominator)`` first.
+    """
+    if denominator is not None or not isinstance(numerator, numbers.Rational):
+        return Expression(constant=Fraction(numerator, denominator))
+    return Expression(constant=numerator)
def symbol(name):
if not isinstance(name, str):
return (symbol(name) for name in names)
-def _operator(func):
- @functools.wraps(func)
- def wrapper(a, b):
- if isinstance(a, numbers.Rational):
- a = constant(a)
- if isinstance(b, numbers.Rational):
- b = constant(b)
- if isinstance(a, Expression) and isinstance(b, Expression):
- return func(a, b)
- raise TypeError('arguments must be linear expressions')
- return wrapper
-
-@_operator
+@_polymorphic_operator
def eq(a, b):
+    """Return the polyhedron defined by the equality constraint ``a == b``."""
+    # Expression.__eq__ is ordinary equality, so the constraint is built
+    # with _eq rather than the == operator.
    return a._eq(b)
-@_operator
+@_polymorphic_operator
def le(a, b):
+    """Return the polyhedron defined by the constraint ``a <= b``."""
    return a <= b
-@_operator
+@_polymorphic_operator
def lt(a, b):
+    """Return the polyhedron defined by the constraint ``a < b``."""
    return a < b
-@_operator
+@_polymorphic_operator
def ge(a, b):
+    """Return the polyhedron defined by the constraint ``a >= b``."""
    return a >= b
-@_operator
+@_polymorphic_operator
def gt(a, b):
+    """Return the polyhedron defined by the constraint ``a > b``."""
    return a > b