New method Expression.subs
author Vivien Maisonneuve <v.maisonneuve@gmail.com>
Wed, 2 Jul 2014 07:00:15 +0000 (09:00 +0200)
committer Vivien Maisonneuve <v.maisonneuve@gmail.com>
Wed, 2 Jul 2014 07:00:15 +0000 (09:00 +0200)
pypol/_islhelper.c
pypol/linexprs.py
pypol/tests/test_linexprs.py

index bc62968..eaacc67 100644 (file)
@@ -36,7 +36,7 @@ static PyObject * isl_basic_set_constraints(PyObject *self, PyObject* args) {
         return NULL;
     }
     bset = (isl_basic_set *) ptr;
-    bset = isl_basic_set_finalize(bset);
+    bset = isl_basic_set_finalize(bset); // this instruction should not be required
     n = isl_basic_set_n_constraint(bset);
     if (n == -1) {
         PyErr_SetString(PyExc_RuntimeError,
index ed68493..b330045 100644 (file)
@@ -249,6 +249,29 @@ class Expression:
                 return left / right
         raise SyntaxError('invalid syntax')
 
+    def subs(self, symbol, expression=None):
+        if expression is None:
+            if isinstance(symbol, dict):
+                symbol = symbol.items()
+            substitutions = symbol
+        else:
+            substitutions = [(symbol, expression)]
+        result = self
+        for symbol, expression in substitutions:
+            symbol = symbolname(symbol)
+            result = result._subs(symbol, expression)
+        return result
+
+    def _subs(self, symbol, expression):
+        coefficients = {name: coefficient
+            for name, coefficient in self.coefficients()
+            if name != symbol}
+        constant = self.constant
+        coefficient = self.coefficient(symbol)
+        result = Expression(coefficients, self.constant)
+        result += coefficient * expression
+        return result
+
     _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')
 
     @classmethod
index 1606ea0..5862351 100644 (file)
@@ -145,6 +145,18 @@ class TestExpression(unittest.TestCase):
         self.assertEqual((self.x + self.y/2 + self.z/3)._toint(),
                 6*self.x + 3*self.y + 2*self.z)
 
+    def test_subs(self):
+        self.assertEqual(self.x.subs('x', 3), 3)
+        self.assertEqual(self.x.subs('x', self.x), self.x)
+        self.assertEqual(self.x.subs('x', self.y), self.y)
+        self.assertEqual(self.x.subs('x', self.x + self.y), self.x + self.y)
+        self.assertEqual(self.x.subs('y', 3), self.x)
+        self.assertEqual(self.pi.subs('x', 3), self.pi)
+        self.assertEqual(self.expr.subs('x', -3), -2 * self.y)
+        self.assertEqual(self.expr.subs([('x', self.y), ('y', self.x)]), 3 - self.x)
+        self.assertEqual(self.expr.subs({'x': self.y, 'y': self.x}), 3 - self.x)
+        self.assertEqual(self.expr.subs({self.x: self.y, self.y: self.x}), 3 - self.x)
+
     def test_fromstring(self):
         self.assertEqual(Expression.fromstring('x'), self.x)
         self.assertEqual(Expression.fromstring('-x'), -self.x)