__all__ = [
'Expression',
- 'Symbol', 'symbols',
+ 'Symbol', 'symbols', 'symbolname', 'symbolnames',
'Constant',
]
self = object().__new__(cls)
self._coefficients = {}
for symbol, coefficient in coefficients:
- if isinstance(symbol, Symbol):
- symbol = symbol.name
- elif not isinstance(symbol, str):
- raise TypeError('symbols must be strings or Symbol instances')
+ symbol = symbolname(symbol)
if isinstance(coefficient, Constant):
coefficient = coefficient.constant
if not isinstance(coefficient, numbers.Rational):
return self
def coefficient(self, symbol):
- if isinstance(symbol, Symbol):
- symbol = symbol.name
- elif not isinstance(symbol, str):
- raise TypeError('symbol must be a string or a Symbol instance')
+ symbol = symbolname(symbol)
try:
return self._coefficients[symbol]
except KeyError:
)
def __new__(cls, name):
- if isinstance(name, Symbol):
- name = name.name
- elif not isinstance(name, str):
- raise TypeError('name must be a string or a Symbol instance')
- name = name.strip()
+ name = symbolname(name)
self = object().__new__(cls)
self._name = name
self._hash = hash(self._name)
return self._hash
def coefficient(self, symbol):
- if isinstance(symbol, Symbol):
- symbol = symbol.name
- elif not isinstance(symbol, str):
- raise TypeError('symbol must be a string or a Symbol instance')
+ symbol = symbolname(symbol)
if symbol == self.name:
return 1
else:
names = names.replace(',', ' ').split()
return (Symbol(name) for name in names)
+def symbolname(symbol):
+    """Normalize *symbol* to its bare name string.
+
+    A ``str`` is returned with leading/trailing whitespace stripped; a
+    ``Symbol`` yields its ``name``.  Any other type raises ``TypeError``.
+    """
+    if isinstance(symbol, str):
+        return symbol.strip()
+    if isinstance(symbol, Symbol):
+        return symbol.name
+    raise TypeError('symbol must be a string or a Symbol instance')
+
+def symbolnames(symbols):
+    """Normalize a collection of symbols to a list of name strings.
+
+    *symbols* is either a single string of names separated by commas
+    and/or whitespace (e.g. ``'x, y z'``), or an iterable whose items are
+    accepted by ``symbolname`` (strings or ``Symbol`` instances).
+
+    A list is returned in both cases.  The result is therefore reusable,
+    and an invalid item raises ``TypeError`` at call time rather than
+    later, when a lazily evaluated result would finally be consumed.
+    """
+    if isinstance(symbols, str):
+        return symbols.replace(',', ' ').split()
+    # Eager list (not a generator): validate every item immediately and
+    # match the list returned by the string branch above.
+    return [symbolname(symbol) for symbol in symbols]
+
class Constant(Expression):
return self._hash
def coefficient(self, symbol):
- if isinstance(symbol, Symbol):
- symbol = symbol.name
- elif not isinstance(symbol, str):
- raise TypeError('symbol must be a string or a Symbol instance')
+ symbol = symbolname(symbol)
return 0
def coefficients(self):