author Vivien Maisonneuve Wed, 28 May 2014 16:27:07 +0000 (18:27 +0200) committer Vivien Maisonneuve Wed, 28 May 2014 16:27:07 +0000 (18:27 +0200)
 pypol/linear.py patch | blob | history

index 6550e92..fabf2a2 100644 (file)
@@ -14,6 +14,30 @@ __all__ = [
]

+def _polymorphic_method(func):
+    """Let a binary method accept rationals; other types get NotImplemented."""
+    @functools.wraps(func)
+    def wrapper(a, b):
+        if isinstance(b, numbers.Rational):
+            b = constant(b)
+        if isinstance(b, Expression):
+            return func(a, b)
+        return NotImplemented
+    return wrapper
+
+def _polymorphic_operator(func):
+    """Let a binary function accept rationals; reject other non-Expressions."""
+    @functools.wraps(func)
+    def wrapper(a, b):
+        a, b = [constant(arg) if isinstance(arg, numbers.Rational) else arg
+                for arg in (a, b)]
+        if not (isinstance(a, Expression) and isinstance(b, Expression)):
+            # A plain function has no reflected fallback, so raise outright.
+            raise TypeError('arguments must be linear expressions')
+        return func(a, b)
+    return wrapper
+
+
class Expression:
"""
This class implements linear expressions.
@@ -96,18 +120,7 @@ class Expression:
def __neg__(self):
return self * -1

-    def _polymorphic(func):
-        @functools.wraps(func)
-        def wrapper(self, other):
-            if isinstance(other, Expression):
-                return func(self, other)
-            if isinstance(other, numbers.Rational):
-                other = Expression(constant=other)
-                return func(self, other)
-            return NotImplemented
-        return wrapper
-
-    @_polymorphic
+    @_polymorphic_method
coefficients = dict(self.coefficients())
for symbol, coefficient in other.coefficients():
@@ -120,7 +133,7 @@ class Expression:

-    @_polymorphic
+    @_polymorphic_method
def __sub__(self, other):
coefficients = dict(self.coefficients())
for symbol, coefficient in other.coefficients():
@@ -133,7 +146,7 @@ class Expression:

__rsub__ = __sub__

-    @_polymorphic
+    @_polymorphic_method
def __mul__(self, other):
if other.isconstant():
coefficients = dict(self.coefficients())
@@ -148,7 +161,7 @@ class Expression:

__rmul__ = __mul__

-    @_polymorphic
+    @_polymorphic_method
def __truediv__(self, other):
if other.isconstant():
coefficients = dict(self.coefficients())
@@ -230,7 +243,7 @@ class Expression:
def fromstring(cls, string):
raise NotImplementedError

-    @_polymorphic
+    @_polymorphic_method
def __eq__(self, other):
# "normal" equality
# see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
@@ -246,29 +259,32 @@ class Expression:
[value.denominator for value in self.values()])
return self * lcm

-    @_polymorphic
+    @_polymorphic_method
def _eq(self, other):
return Polyhedron(equalities=[(self - other)._canonify()])

-    @_polymorphic
+    @_polymorphic_method
def __le__(self, other):
return Polyhedron(inequalities=[(self - other)._canonify()])

-    @_polymorphic
+    @_polymorphic_method
def __lt__(self, other):
return Polyhedron(inequalities=[(self - other)._canonify() + 1])

-    @_polymorphic
+    @_polymorphic_method
def __ge__(self, other):
return Polyhedron(inequalities=[(other - self)._canonify()])

-    @_polymorphic
+    @_polymorphic_method
def __gt__(self, other):
return Polyhedron(inequalities=[(other - self)._canonify() + 1])

def constant(numerator=0, denominator=None):
-    return Expression(constant=Fraction(numerator, denominator))
+    if denominator is not None or not isinstance(numerator, numbers.Rational):
+        # Explicit denominator or non-rational numerator: let Fraction build it.
+        return Expression(constant=Fraction(numerator, denominator))
+    return Expression(constant=numerator)

def symbol(name):
if not isinstance(name, str):
@@ -281,35 +297,23 @@ def symbols(names):
return (symbol(name) for name in names)

-def _operator(func):
-    @functools.wraps(func)
-    def wrapper(a, b):
-        if isinstance(a, numbers.Rational):
-            a = constant(a)
-        if isinstance(b, numbers.Rational):
-            b = constant(b)
-        if isinstance(a, Expression) and isinstance(b, Expression):
-            return func(a, b)
-        raise TypeError('arguments must be linear expressions')
-    return wrapper
-
-@_operator
+@_polymorphic_operator  # a == b is plain equality, so build the Polyhedron via _eq
def eq(a, b):
return a._eq(b)

-@_operator
+@_polymorphic_operator  # rationals coerced by constant(); Expression.__le__ builds the Polyhedron
def le(a, b):
return a <= b

-@_operator
+@_polymorphic_operator  # rationals coerced by constant(); Expression.__lt__ builds the Polyhedron
def lt(a, b):
return a < b

-@_operator
+@_polymorphic_operator  # rationals coerced by constant(); Expression.__ge__ builds the Polyhedron
def ge(a, b):
return a >= b

-@_operator
+@_polymorphic_operator  # rationals coerced by constant(); Expression.__gt__ builds the Polyhedron
def gt(a, b):
return a > b