From: Vivien Maisonneuve
Date: Wed, 2 Jul 2014 05:08:11 +0000 (+0200)
Subject: Helper functions symbolname and symbolnames
X-Git-Tag: 1.0~172
X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/commitdiff_plain/556abe7f3b2c7e3985560f3e3cfb6f66bacc4122

Helper functions symbolname and symbolnames

symbolnames() returns a list rather than a one-shot generator, so that
callers such as Domain.project_out(), which previously built a list, can
traverse the result more than once and get TypeError eagerly.
---

diff --git a/pypol/domains.py b/pypol/domains.py
index b4780fc..6060dc9 100644
--- a/pypol/domains.py
+++ b/pypol/domains.py
@@ -5,7 +5,7 @@ import re
 
 from . import islhelper
 from .islhelper import mainctx, libisl, isl_set_basic_sets
-from .linexprs import Expression, Symbol
+from .linexprs import Expression, Symbol, symbolnames
 
 
 __all__ = [
@@ -154,15 +154,7 @@ class Domain:
 
     def project_out(self, symbols):
         # use to remove certain variables
-        if isinstance(symbols, str):
-            symbols = symbols.replace(',', ' ').split()
-        else:
-            symbols = list(symbols)
-        for i, symbol in enumerate(symbols):
-            if isinstance(symbol, Symbol):
-                symbols[i] = symbol.name
-            elif not isinstance(symbol, str):
-                raise TypeError('symbols must be strings or Symbol instances')
+        symbols = symbolnames(symbols)
         islset = self._toislset(self.polyhedra, self.symbols)
         # the trick is to walk symbols in reverse order, to avoid index updates
         for index, symbol in reversed(list(enumerate(self.symbols))):
@@ -346,6 +338,7 @@ class Domain:
     def tosympy(self):
         raise NotImplementedError
 
+
 def And(*domains):
     if len(domains) == 0:
         from .polyhedra import Universe
diff --git a/pypol/linexprs.py b/pypol/linexprs.py
index 9449219..ed68493 100644
--- a/pypol/linexprs.py
+++ b/pypol/linexprs.py
@@ -9,7 +9,7 @@ from fractions import Fraction, gcd
 
 __all__ = [
     'Expression',
-    'Symbol', 'symbols',
+    'Symbol', 'symbols', 'symbolname', 'symbolnames',
     'Constant',
 ]
 
@@ -59,10 +59,7 @@ class Expression:
         self = object().__new__(cls)
         self._coefficients = {}
         for symbol, coefficient in coefficients:
-            if isinstance(symbol, Symbol):
-                symbol = symbol.name
-            elif not isinstance(symbol, str):
-                raise TypeError('symbols must be strings or Symbol instances')
+            symbol = symbolname(symbol)
             if isinstance(coefficient, Constant):
                 coefficient = coefficient.constant
             if not isinstance(coefficient, numbers.Rational):
@@ -82,10 +79,7 @@ class Expression:
         return self
 
     def coefficient(self, symbol):
-        if isinstance(symbol, Symbol):
-            symbol = symbol.name
-        elif not isinstance(symbol, str):
-            raise TypeError('symbol must be a string or a Symbol instance')
+        symbol = symbolname(symbol)
         try:
             return self._coefficients[symbol]
         except KeyError:
@@ -345,11 +339,7 @@ class Symbol(Expression):
     )
 
     def __new__(cls, name):
-        if isinstance(name, Symbol):
-            name = name.name
-        elif not isinstance(name, str):
-            raise TypeError('name must be a string or a Symbol instance')
-        name = name.strip()
+        name = symbolname(name)
         self = object().__new__(cls)
         self._name = name
         self._hash = hash(self._name)
@@ -363,10 +353,7 @@ class Symbol(Expression):
         return self._hash
 
     def coefficient(self, symbol):
-        if isinstance(symbol, Symbol):
-            symbol = symbol.name
-        elif not isinstance(symbol, str):
-            raise TypeError('symbol must be a string or a Symbol instance')
+        symbol = symbolname(symbol)
         if symbol == self.name:
             return 1
         else:
@@ -420,6 +407,27 @@ def symbols(names):
         names = names.replace(',', ' ').split()
     return (Symbol(name) for name in names)
 
+def symbolname(symbol):
+    """Return the name of symbol (a str, which is stripped, or a Symbol)."""
+    if isinstance(symbol, str):
+        return symbol.strip()
+    elif isinstance(symbol, Symbol):
+        return symbol.name
+    else:
+        raise TypeError('symbol must be a string or a Symbol instance')
+
+def symbolnames(symbols):
+    """Return a list of symbol names.
+
+    symbols is either a string of names separated by commas and/or
+    whitespace, or an iterable of strings and Symbol instances.  A list
+    (not a one-shot generator) is returned so that callers can traverse it
+    more than once, and so that invalid items raise TypeError immediately.
+    """
+    if isinstance(symbols, str):
+        return symbols.replace(',', ' ').split()
+    return [symbolname(symbol) for symbol in symbols]
+
 
 class Constant(Expression):
 
@@ -441,10 +449,7 @@ class Constant(Expression):
         return self._hash
 
     def coefficient(self, symbol):
-        if isinstance(symbol, Symbol):
-            symbol = symbol.name
-        elif not isinstance(symbol, str):
-            raise TypeError('symbol must be a string or a Symbol instance')
+        symbol = symbolname(symbol)
         return 0
 
     def coefficients(self):