X-Git-Url: https://scm.cri.ensmp.fr/git/linpy.git/blobdiff_plain/2e558859456a109279713a2cbdd6c48a70a171c6..b02f9551644488e5943f968ac847fe4ed7690d6b:/linpy/linexprs.py diff --git a/linpy/linexprs.py b/linpy/linexprs.py index 834c3b4..b2cec53 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -122,7 +122,7 @@ class LinExpr: """ if not isinstance(symbol, Symbol): raise TypeError('symbol must be a Symbol instance') - return Rational(self._coefficients.get(symbol, 0)) + return self._coefficients.get(symbol, 0) __getitem__ = coefficient @@ -132,14 +132,14 @@ class LinExpr: expression. The constant term is ignored. """ for symbol, coefficient in self._coefficients.items(): - yield symbol, Rational(coefficient) + yield symbol, coefficient @property def constant(self): """ The constant term of the expression. """ - return Rational(self._constant) + return self._constant @property def symbols(self): @@ -180,8 +180,8 @@ class LinExpr: term. """ for coefficient in self._coefficients.values(): - yield Rational(coefficient) - yield Rational(self._constant) + yield coefficient + yield self._constant def __bool__(self): return True @@ -411,7 +411,7 @@ class LinExpr: @classmethod def fromsympy(cls, expr): """ - Create a linear expression from a sympy expression. Raise ValueError is + Create a linear expression from a sympy expression. Raise TypeError is the sympy expression is not linear. """ import sympy @@ -421,12 +421,18 @@ class LinExpr: coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient + elif isinstance(symbol, sympy.Dummy): + # we cannot properly convert dummy symbols + raise TypeError('cannot convert dummy symbols') elif isinstance(symbol, sympy.Symbol): symbol = Symbol(symbol.name) coefficients.append((symbol, coefficient)) else: - raise ValueError('non-linear expression: {!r}'.format(expr)) - return LinExpr(coefficients, constant) + raise TypeError('non-linear expression: {!r}'.format(expr)) + expr = LinExpr(coefficients, constant) + if not isinstance(expr, cls): + raise TypeError('cannot convert to a {} instance'.format(cls.__name__)) + return expr def tosympy(self): """ @@ -507,15 +513,20 @@ class Symbol(LinExpr): def _repr_latex_(self): return '$${}$$'.format(self.name) - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Dummy): - return Dummy(expr.name) - elif isinstance(expr, sympy.Symbol): - return Symbol(expr.name) - else: - raise TypeError('expr must be a sympy.Symbol instance') + +def symbols(names): + """ + This function returns a tuple of symbols whose names are taken from a comma + or whitespace delimited string, or a sequence of strings. It is useful to + define several symbols at once. + + >>> x, y = symbols('x y') + >>> x, y = symbols('x, y') + >>> x, y = symbols(['x', 'y']) + """ + if isinstance(names, str): + names = names.replace(',', ' ').split() + return tuple(Symbol(name) for name in names) class Dummy(Symbol): @@ -571,21 +582,6 @@ class Dummy(Symbol): return '$${}_{{{}}}$$'.format(self.name, self._index) -def symbols(names): - """ - This function returns a tuple of symbols whose names are taken from a comma - or whitespace delimited string, or a sequence of strings. It is useful to - define several symbols at once. - - >>> x, y = symbols('x y') - >>> x, y = symbols('x, y') - >>> x, y = symbols(['x', 'y']) - """ - if isinstance(names, str): - names = names.replace(',', ' ').split() - return tuple(Symbol(name) for name in names) - - class Rational(LinExpr, Fraction): """ A particular case of linear expressions are rational values, i.e. linear @@ -632,13 +628,3 @@ class Rational(LinExpr, Fraction): else: return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator, self.denominator) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Rational): - return Rational(expr.p, expr.q) - elif isinstance(expr, numbers.Rational): - return Rational(expr) - else: - raise TypeError('expr must be a sympy.Rational instance')