Improve comparison methods in LinExpr
[linpy.git] / examples / nsad2010.py
index e2cacc7..9359315 100755 (executable)
@@ -1,52 +1,46 @@
 #!/usr/bin/env python3
 
-from pypol import *
-
-def affine_derivative_closure(T, x0s):
-
-    xs = [Symbol("{}'".format(x0.name)) for x0 in x0s]
-    dxs = [Symbol('d{}'.format(x0.name)) for x0 in x0s]
-    k = Symbol('k')
-
-    for x in T.symbols:
-        assert x in x0s + xs
-    for dx in dxs:
-        assert dx.name not in T.symbols
-    assert k.name not in T.symbols
-
-    T0 = T
-
-    T1 = T0
-    for i, x0 in enumerate(x0s):
-        x, dx = xs[i], dxs[i]
-        T1 &= Eq(dx, x - x0)
-
-    T2 = T1.project_out(T0.symbols)
-
-    T3_eqs = []
-    T3_ins = []
-    for T2_eq in T2.equalities:
-        c = T2_eq.constant
-        T3_eq = T2_eq + (k - 1) * c
-        T3_eqs.append(T3_eq)
-    for T2_in in T2.inequalities:
-        c = T2_in.constant
-        T3_in = T2_in + (k - 1) * c
-        T3_ins.append(T3_in)
-    T3 = Polyhedron(T3_eqs, T3_ins)
-    T3 &= Ge(k, 0)
-
-    T4 = T3.project_out([k])
-    for i, dx in enumerate(dxs):
-        x0, x = x0s[i], xs[i]
-        T4 &= Eq(dx, x - x0)
-    T4 = T4.project_out(dxs)
-
-    return T4
-
-i0, j0, i, j = symbols(['i', 'j', "i'", "j'"])
-T = Eq(i, i0 + 2) & Eq(j, j0 + 1)
-
-print('T =', T)
-Tstar = affine_derivative_closure(T, [i0, j0])
-print('T* =', Tstar)
+from linpy import *
+
+
+class Transformer:  # bundles a polyhedron with two ordered symbol lists: range_symbols and domain_symbols
+
+    def __new__(cls, polyhedron, range_symbols, domain_symbols):  # instances are fully set up here; the class defines no __init__
+        self = object().__new__(cls)  # object.__new__ reached through a throwaway instance: allocates a bare cls instance
+        self.polyhedron = polyhedron  # stored as given
+        self.range_symbols = range_symbols  # stored as given; paired positionally with domain_symbols in star()
+        self.domain_symbols = domain_symbols  # stored as given
+        return self
+
+    @property
+    def symbols(self):  # read-only view over both symbol collections
+        return self.range_symbols + self.domain_symbols  # range symbols first, then domain symbols (both are lists in the demo call)
+
+    def star(self):  # returns a new Transformer; NOTE(review): appears to be the closure formerly named affine_derivative_closure - confirm
+        delta_symbols = [symbol.asdummy() for symbol in self.range_symbols]  # one asdummy() result per range symbol, used as dx below
+        k = Dummy('k')  # auxiliary symbol, constrained to k >= 0 further down
+        polyhedron = self.polyhedron  # start from the stored relation
+        for x, xprime, dx in zip(self.range_symbols, self.domain_symbols, delta_symbols):  # zip pairs the three lists positionally
+            polyhedron &= Eq(dx, xprime - x)  # add the constraint dx == xprime - x
+        polyhedron = polyhedron.project(self.symbols)  # presumably eliminates the range/domain symbols, leaving the dx - confirm project() semantics
+        equalities, inequalities = [], []  # adjusted constraints are collected here
+        for equality in polyhedron.equalities:
+            equality += (k-1) * equality.constant  # rebinds the loop variable; presumably turns the constant term c into k*c - TODO confirm
+            equalities.append(equality)
+        for inequality in polyhedron.inequalities:
+            inequality += (k-1) * inequality.constant  # same adjustment as for the equalities
+            inequalities.append(inequality)
+        polyhedron = Polyhedron(equalities, inequalities) & Ge(k, 0)  # rebuild from the adjusted constraints and require k >= 0
+        polyhedron = polyhedron.project([k])  # presumably eliminates k - confirm
+        for x, xprime, dx in zip(self.range_symbols, self.domain_symbols, delta_symbols):  # same pairing as in the first loop
+            polyhedron &= Eq(dx, xprime - x)  # re-add dx == xprime - x to tie the deltas back to the symbols
+        polyhedron = polyhedron.project(delta_symbols)  # presumably eliminates the dx symbols - confirm
+        return Transformer(polyhedron, self.range_symbols, self.domain_symbols)  # new Transformer reusing the same symbol lists
+
+
+if __name__ == '__main__':  # demo, executed only when the file is run as a script
+    i, iprime, j, jprime = symbols("i i' j j'")  # unpacked into four symbols; names presumably i, i', j, j'
+    transformer = Transformer(Eq(iprime, i + 2) & Eq(jprime, j + 1),  # relation i' == i + 2 and j' == j + 1
+        [i, j], [iprime, jprime])  # range_symbols, then domain_symbols
+    print('T  =', transformer.polyhedron)  # the input relation
+    print('T* =', transformer.star().polyhedron)  # the polyhedron of the Transformer returned by star()