bc062b6438f335c14d65444bbbbea6607e1f61fb
[linpy.git] / pypol / tests / test_linexprs.py
1 import functools
2 import unittest
3
4 from fractions import Fraction
5
6 from ..linexprs import *
7 from .libhelper import requires_sympy
8
9
class TestExpression(unittest.TestCase):
    """Tests for the general :class:`Expression` API.

    Covers construction (including automatic narrowing to ``Symbol`` /
    ``Constant``), coefficient and constant access, arithmetic, equality,
    substitution, parsing, printing, and optional sympy conversion.
    """

    def setUp(self):
        # Unit-coefficient single-variable expressions x, y, z.
        self.x, self.y, self.z = (Expression({name: 1}) for name in 'xyz')
        # Constant-only expressions.
        self.zero = Expression(constant=0)
        self.one = Expression(constant=1)
        self.pi = Expression(constant=Fraction(22, 7))
        # Shared mixed fixture: x - 2*y + 3.
        self.expr = self.x - 2*self.y + 3

    def test_new(self):
        """Construction narrows to subclasses and validates its arguments."""
        self.assertIsInstance(self.x, Symbol)
        self.assertIsInstance(self.pi, Constant)
        # A symbol plus a constant is neither a pure symbol nor a constant.
        for narrow_cls in (Symbol, Constant):
            self.assertNotIsInstance(self.x + self.pi, narrow_cls)
        # A coefficient other than 1 prevents narrowing to Symbol.
        doubled_x = Expression({'x': 2})
        self.assertNotIsInstance(doubled_x, Symbol)
        self.assertRaises(TypeError, Expression, 'x + y', 2)
        # Coefficient keys may be names or symbol expressions.
        self.assertEqual(Expression({'x': 2}), Expression({self.x: 2}))
        for bad_coefficients in ({0: 2}, {'x': '2'}):
            self.assertRaises(TypeError, Expression, bad_coefficients)
        # The constant may be a number or a numeric string, but not junk.
        for raw_constant in (1, '1'):
            self.assertEqual(Expression(constant=raw_constant),
                             Expression(constant=self.one))
        self.assertRaises(ValueError, Expression, constant='a')

    def test_coefficient(self):
        """coefficient() accepts names or symbols and rejects anything else."""
        expected = [('x', 1), ('y', -2), (self.y, -2), ('z', 0)]
        for symbol, coefficient in expected:
            self.assertEqual(self.expr.coefficient(symbol), coefficient)
        for bad_symbol in (0, self.expr):
            self.assertRaises(TypeError, self.expr.coefficient, bad_symbol)

    def test_getitem(self):
        """Indexing mirrors coefficient()."""
        expected = [('x', 1), ('y', -2), (self.y, -2), ('z', 0)]
        for symbol, coefficient in expected:
            self.assertEqual(self.expr[symbol], coefficient)
        for bad_symbol in (0, self.expr):
            self.assertRaises(TypeError, lambda: self.expr[bad_symbol])

    def test_coefficients(self):
        """coefficients() yields (name, coefficient) pairs for every symbol."""
        expected_pairs = [('x', 1), ('y', -2)]
        self.assertCountEqual(self.expr.coefficients(), expected_pairs)

    def test_constant(self):
        """The constant term is exposed via the ``constant`` attribute."""
        cases = [(self.x, 0), (self.pi, Fraction(22, 7)), (self.expr, 3)]
        for expression, constant in cases:
            self.assertEqual(expression.constant, constant)

    def test_symbols(self):
        """``symbols`` lists the names of the variables present."""
        cases = [(self.x, ['x']), (self.pi, []), (self.expr, ['x', 'y'])]
        for expression, names in cases:
            self.assertCountEqual(expression.symbols, names)

    def test_dimension(self):
        """``dimension`` counts the variables present."""
        cases = [(self.x, 1), (self.pi, 0), (self.expr, 2)]
        for expression, dimension in cases:
            self.assertEqual(expression.dimension, dimension)

    def test_isconstant(self):
        """Only expressions without variables are constant."""
        self.assertFalse(self.x.isconstant())
        self.assertTrue(self.pi.isconstant())
        self.assertFalse(self.expr.isconstant())

    def test_issymbol(self):
        """Only a bare unit-coefficient variable is a symbol."""
        self.assertTrue(self.x.issymbol())
        self.assertFalse(self.pi.issymbol())
        self.assertFalse(self.expr.issymbol())

    def test_values(self):
        """values() yields all coefficients followed by the constant term."""
        expected_values = [1, -2, 3]
        self.assertCountEqual(self.expr.values(), expected_values)

    def test_bool(self):
        """Only the zero expression is falsy."""
        self.assertTrue(self.x)
        self.assertFalse(self.zero)
        self.assertTrue(self.pi)
        self.assertTrue(self.expr)

    def test_pos(self):
        """Unary plus is the identity."""
        positive = +self.expr
        self.assertEqual(positive, self.expr)

    def test_neg(self):
        """Unary minus negates every coefficient and the constant."""
        negated = -self.expr
        self.assertEqual(negated, -self.x + 2*self.y - 3)

    def test_add(self):
        """Addition works with numbers on either side and merges like terms."""
        pi_value = Fraction(22, 7)
        self.assertEqual(self.x + pi_value, self.x + self.pi)
        self.assertEqual(pi_value + self.x, self.x + self.pi)
        self.assertEqual(self.x + self.x, 2 * self.x)
        self.assertEqual(self.expr + 2*self.y, self.x + 3)

    def test_sub(self):
        """Subtraction cancels like terms and works with a numeric left operand."""
        self.assertEqual(self.x - self.x, 0)
        self.assertEqual(self.expr - 3, self.x - 2*self.y)
        self.assertEqual(0 - self.x, -self.x)

    def test_mul(self):
        """Multiplication by a number scales every term, from either side."""
        self.assertEqual(self.pi * 7, 22)
        self.assertEqual(self.expr * 0, 0)
        self.assertEqual(0 * self.expr, 0)
        self.assertEqual(self.expr * 2, 2*self.x - 4*self.y + 6)

    def test_truediv(self):
        """Division by a number scales every term; zero is rejected."""
        self.assertRaises(ZeroDivisionError, lambda: self.expr / 0)
        self.assertEqual(self.expr / 2, self.x / 2 - self.y + Fraction(3, 2))

    def test_eq(self):
        """Equality is structural and compares with plain numbers."""
        self.assertEqual(self.expr, self.expr)
        self.assertNotEqual(self.x, self.y)
        self.assertEqual(self.zero, 0)

    def test__toint(self):
        """_toint() clears fractional coefficients by scaling."""
        fractional = self.x + self.y/2 + self.z/3
        self.assertEqual(fractional._toint(),
                         6*self.x + 3*self.y + 2*self.z)

    def test_subs(self):
        """subs() accepts a single pair, a list of pairs, or a mapping."""
        substitute_x = self.x.subs
        self.assertEqual(substitute_x('x', 3), 3)
        self.assertEqual(substitute_x('x', self.x), self.x)
        self.assertEqual(substitute_x('x', self.y), self.y)
        self.assertEqual(substitute_x('x', self.x + self.y), self.x + self.y)
        self.assertEqual(substitute_x('y', 3), self.x)
        self.assertEqual(self.pi.subs('x', 3), self.pi)
        self.assertEqual(self.expr.subs('x', -3), -2 * self.y)
        self.assertEqual(self.expr.subs([('x', self.y), ('y', self.x)]),
                         3 - self.x)
        self.assertEqual(self.expr.subs({'x': self.z, 'y': self.z}),
                         3 - self.z)
        self.assertEqual(self.expr.subs({self.x: self.z, self.y: self.z}),
                         3 - self.z)

    def test_fromstring(self):
        """fromstring() parses implicit and explicit multiplication."""
        parse = Expression.fromstring
        cases = [
            ('x', self.x),
            ('-x', -self.x),
            ('22/7', self.pi),
            ('x - 2y + 3', self.expr),
            ('x - (3-1)y + 3', self.expr),
            ('x - 2*y + 3', self.expr),
        ]
        for text, expected in cases:
            self.assertEqual(parse(text), expected)

    def test_str(self):
        """str() renders a human-readable linear expression."""
        cases = [
            (Expression(), '0'),
            (self.x, 'x'),
            (-self.x, '-x'),
            (self.pi, '22/7'),
            (self.expr, 'x - 2*y + 3'),
        ]
        for expression, text in cases:
            self.assertEqual(str(expression), text)

    def test_repr(self):
        """repr() names the most specific subclass."""
        cases = [
            (self.x, "Symbol('x')"),
            (self.one, 'Constant(1)'),
            (self.pi, 'Constant(22, 7)'),
            (self.x + self.one, "Expression('x + 1')"),
            (self.expr, "Expression('x - 2*y + 3')"),
        ]
        for expression, text in cases:
            self.assertEqual(repr(expression), text)

    @requires_sympy
    def test_fromsympy(self):
        """fromsympy() converts linear sympy expressions and rejects others."""
        import sympy
        sp_x, sp_y = sympy.symbols('x y')
        convert = Expression.fromsympy
        self.assertEqual(convert(sp_x), self.x)
        self.assertEqual(convert(sympy.Rational(22, 7)), self.pi)
        self.assertEqual(convert(sp_x - 2*sp_y + 3), self.expr)
        self.assertRaises(ValueError, convert, sp_x*sp_y)

    @requires_sympy
    def test_tosympy(self):
        """tosympy() produces the equivalent sympy expression."""
        import sympy
        sp_x, sp_y = sympy.symbols('x y')
        self.assertEqual(self.x.tosympy(), sp_x)
        self.assertEqual(self.pi.tosympy(), sympy.Rational(22, 7))
        self.assertEqual(self.expr.tosympy(), sp_x - 2*sp_y + 3)
185
186
class TestSymbol(unittest.TestCase):
    """Tests for :class:`Symbol`, a single named variable."""

    def setUp(self):
        self.x, self.y = Symbol('x'), Symbol('y')

    def test_new(self):
        """Names are stripped, symbols are accepted, non-names are rejected."""
        for source in (' x ', self.x):
            self.assertEqual(Symbol(source), self.x)
        self.assertRaises(TypeError, Symbol, 1)

    def test_name(self):
        """``name`` exposes the variable name."""
        self.assertEqual(self.x.name, 'x')

    def test_issymbol(self):
        """A Symbol always reports itself as a symbol."""
        self.assertTrue(self.x.issymbol())

    def test_fromstring(self):
        """fromstring() parses an identifier and rejects a number."""
        self.assertEqual(Symbol.fromstring('x'), self.x)
        self.assertRaises(SyntaxError, Symbol.fromstring, '1')

    def test_str(self):
        """str() is the bare name."""
        self.assertEqual(str(self.x), 'x')

    def test_repr(self):
        """repr() is a constructor call."""
        self.assertEqual(repr(self.x), "Symbol('x')")

    @requires_sympy
    def test_fromsympy(self):
        """fromsympy() accepts only a bare sympy Symbol."""
        import sympy
        sp_x = sympy.Symbol('x')
        self.assertEqual(Symbol.fromsympy(sp_x), self.x)
        for not_a_symbol in (sympy.Rational(22, 7), 2 * sp_x, sp_x*sp_x):
            self.assertRaises(TypeError, Symbol.fromsympy, not_a_symbol)

    def test_symbols(self):
        """symbols() splits on whitespace or commas, or takes a list of names."""
        expected = [self.x, self.y]
        for spec in ('x y', 'x,y', ['x', 'y']):
            self.assertListEqual(list(symbols(spec)), expected)
232
233
class TestConstant(unittest.TestCase):
    """Tests for :class:`Constant`, an expression with no variables.

    Mirrors the structure of the Symbol tests, including ``issymbol`` and
    ``str`` checks that the constant case previously lacked.
    """

    def setUp(self):
        self.zero = Constant(0)
        self.one = Constant(1)
        self.pi = Constant(Fraction(22, 7))

    def test_new(self):
        """Constants default to 0 and accept numbers, constants and strings."""
        self.assertEqual(Constant(), self.zero)
        self.assertEqual(Constant(1), self.one)
        self.assertEqual(Constant(self.pi), self.pi)
        self.assertEqual(Constant('22/7'), self.pi)

    def test_isconstant(self):
        """Every Constant, zero or not, reports itself as constant."""
        self.assertTrue(self.zero.isconstant())
        self.assertTrue(self.pi.isconstant())

    def test_issymbol(self):
        """A Constant is never a symbol."""
        self.assertFalse(self.pi.issymbol())

    def test_bool(self):
        """Only the zero constant is falsy."""
        self.assertFalse(self.zero)
        self.assertTrue(self.pi)

    def test_fromstring(self):
        """fromstring() parses rationals and rejects non-numbers and non-strings."""
        self.assertEqual(Constant.fromstring('22/7'), self.pi)
        with self.assertRaises(ValueError):
            Constant.fromstring('a')
        with self.assertRaises(TypeError):
            Constant.fromstring(1)

    def test_str(self):
        """str() renders a rational constant as numerator/denominator."""
        self.assertEqual(str(self.pi), '22/7')

    def test_repr(self):
        """repr() is a constructor call with numerator (and denominator)."""
        self.assertEqual(repr(self.zero), 'Constant(0)')
        self.assertEqual(repr(self.one), 'Constant(1)')
        self.assertEqual(repr(self.pi), 'Constant(22, 7)')

    @requires_sympy
    def test_fromsympy(self):
        """fromsympy() accepts sympy rationals and rejects symbols."""
        import sympy
        self.assertEqual(Constant.fromsympy(sympy.Rational(22, 7)), self.pi)
        with self.assertRaises(TypeError):
            Constant.fromsympy(sympy.Symbol('x'))