Add _repr_latex_ methods for IPython prettyprint
[linpy.git] / pypol / linexprs.py
index e73449e..c8745b5 100644 (file)
@@ -44,11 +44,11 @@ class Expression:
             if not isinstance(symbol, Symbol):
                 raise TypeError('symbols must be Symbol instances')
             if not isinstance(coefficient, numbers.Rational):
             if not isinstance(symbol, Symbol):
                 raise TypeError('symbols must be Symbol instances')
             if not isinstance(coefficient, numbers.Rational):
-                raise TypeError('coefficients must be Rational instances')
+                raise TypeError('coefficients must be rational numbers')
         coefficients = [(symbol, Fraction(coefficient))
             for symbol, coefficient in coefficients if coefficient != 0]
         if not isinstance(constant, numbers.Rational):
         coefficients = [(symbol, Fraction(coefficient))
             for symbol, coefficient in coefficients if coefficient != 0]
         if not isinstance(constant, numbers.Rational):
-            raise TypeError('constant must be a Rational instance')
+            raise TypeError('constant must be a rational number')
         constant = Fraction(constant)
         if len(coefficients) == 0:
             return Rational(constant)
         constant = Fraction(constant)
         if len(coefficients) == 0:
             return Rational(constant)
@@ -240,18 +240,17 @@ class Expression:
         string = ''
         for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
         string = ''
         for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
-                string += '' if i == 0 else ' + '
-                string += '{!r}'.format(symbol)
+                if i != 0:
+                    string += ' + '
             elif coefficient == -1:
                 string += '-' if i == 0 else ' - '
             elif coefficient == -1:
                 string += '-' if i == 0 else ' - '
-                string += '{!r}'.format(symbol)
+            elif i == 0:
+                string += '{}*'.format(coefficient)
+            elif coefficient > 0:
+                string += ' + {}*'.format(coefficient)
             else:
             else:
-                if i == 0:
-                    string += '{}*{!r}'.format(coefficient, symbol)
-                elif coefficient > 0:
-                    string += ' + {}*{!r}'.format(coefficient, symbol)
-                else:
-                    string += ' - {}*{!r}'.format(-coefficient, symbol)
+                string += ' - {}*'.format(-coefficient)
+            string += '{}'.format(symbol)
         constant = self.constant
         if len(string) == 0:
             string += '{}'.format(constant)
         constant = self.constant
         if len(string) == 0:
             string += '{}'.format(constant)
@@ -261,6 +260,30 @@ class Expression:
             string += ' - {}'.format(-constant)
         return string
 
             string += ' - {}'.format(-constant)
         return string
 
+    def _repr_latex_(self):
+        string = ''
+        for i, (symbol, coefficient) in enumerate(self.coefficients()):
+            if coefficient == 1:
+                if i != 0:
+                    string += ' + '
+            elif coefficient == -1:
+                string += '-' if i == 0 else ' - '
+            elif i == 0:
+                string += '{}'.format(coefficient._repr_latex_().strip('$'))
+            elif coefficient > 0:
+                string += ' + {}'.format(coefficient._repr_latex_().strip('$'))
+            elif coefficient < 0:
+                string += ' - {}'.format((-coefficient)._repr_latex_().strip('$'))
+            string += '{}'.format(symbol._repr_latex_().strip('$'))
+        constant = self.constant
+        if len(string) == 0:
+            string += '{}'.format(constant._repr_latex_().strip('$'))
+        elif constant > 0:
+            string += ' + {}'.format(constant._repr_latex_().strip('$'))
+        elif constant < 0:
+            string += ' - {}'.format((-constant)._repr_latex_().strip('$'))
+        return '${}$'.format(string)
+
     def _parenstr(self, always=False):
         string = str(self)
         if not always and (self.isconstant() or self.issymbol()):
     def _parenstr(self, always=False):
         string = str(self)
         if not always and (self.isconstant() or self.issymbol()):
@@ -340,6 +363,9 @@ class Symbol(Expression):
     def __repr__(self):
         return self.name
 
     def __repr__(self):
         return self.name
 
+    def _repr_latex_(self):
+        return '${}$'.format(self.name)
+
     @classmethod
     def fromsympy(cls, expr):
         import sympy
     @classmethod
     def fromsympy(cls, expr):
         import sympy
@@ -378,6 +404,9 @@ class Dummy(Symbol):
     def __repr__(self):
         return '_{}'.format(self.name)
 
     def __repr__(self):
         return '_{}'.format(self.name)
 
+    def _repr_latex_(self):
+        return '${}_{{{}}}$'.format(self.name, self._index)
+
 
 def symbols(names):
     if isinstance(names, str):
 
 def symbols(names):
     if isinstance(names, str):
@@ -430,7 +459,23 @@ class Rational(Expression, Fraction):
     def fromstring(cls, string):
         if not isinstance(string, str):
             raise TypeError('string must be a string instance')
     def fromstring(cls, string):
         if not isinstance(string, str):
             raise TypeError('string must be a string instance')
-        return Rational(Fraction(string))
+        return Rational(string)
+
+    def __repr__(self):
+        if self.denominator == 1:
+            return '{!r}'.format(self.numerator)
+        else:
+            return '{!r}/{!r}'.format(self.numerator, self.denominator)
+
+    def _repr_latex_(self):
+        if self.denominator == 1:
+            return '${}$'.format(self.numerator)
+        elif self.numerator < 0:
+            return '$-\\frac{{{}}}{{{}}}$'.format(-self.numerator,
+                self.denominator)
+        else:
+            return '$\\frac{{{}}}{{{}}}$'.format(self.numerator,
+                self.denominator)
 
     @classmethod
     def fromsympy(cls, expr):
 
     @classmethod
     def fromsympy(cls, expr):