1 # Copyright 2014 MINES ParisTech
3 # This file is part of LinPy.
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
20 from fractions
import Fraction
22 from ..linexprs
import Dummy
, LinExpr
, Rational
, Symbol
, symbols
23 from .libhelper
import requires_sympy
26 class TestLinExpr(unittest
.TestCase
):
32 self
.zero
= LinExpr(constant
=0)
33 self
.one
= LinExpr(constant
=1)
34 self
.pi
= LinExpr(constant
=Fraction(22, 7))
35 self
.expr
= self
.x
- 2*self
.y
+ 3
38 self
.assertIsInstance(LinExpr(coefficients
={self
.x
: 1}), Symbol
)
39 self
.assertIsInstance(LinExpr(constant
=self
.pi
), Rational
)
40 self
.assertNotIsInstance(self
.x
+ self
.pi
, Symbol
)
41 self
.assertNotIsInstance(self
.x
+ self
.pi
, Rational
)
42 xx
= LinExpr({self
.x
: 2})
43 self
.assertNotIsInstance(xx
, Symbol
)
44 with self
.assertRaises(TypeError):
46 with self
.assertRaises(TypeError):
48 with self
.assertRaises(TypeError):
50 self
.assertEqual(LinExpr(constant
=1), LinExpr(constant
=self
.one
))
51 self
.assertEqual(LinExpr(constant
='1'), LinExpr(constant
=self
.one
))
52 with self
.assertRaises(ValueError):
55 def test_coefficient(self
):
56 self
.assertEqual(self
.expr
.coefficient(self
.x
), 1)
57 self
.assertEqual(self
.expr
.coefficient(self
.y
), -2)
58 self
.assertEqual(self
.expr
.coefficient(self
.z
), 0)
59 with self
.assertRaises(TypeError):
60 self
.expr
.coefficients('x')
61 with self
.assertRaises(TypeError):
62 self
.expr
.coefficient(0)
63 with self
.assertRaises(TypeError):
64 self
.expr
.coefficient(self
.expr
)
66 def test_getitem(self
):
67 self
.assertEqual(self
.expr
[self
.x
], 1)
68 self
.assertEqual(self
.expr
[self
.y
], -2)
69 self
.assertEqual(self
.expr
[self
.z
], 0)
70 with self
.assertRaises(TypeError):
71 self
.assertEqual(self
.expr
['x'], 1)
72 with self
.assertRaises(TypeError):
74 with self
.assertRaises(TypeError):
77 def test_coefficients(self
):
78 self
.assertListEqual(list(self
.expr
.coefficients()),
79 [(self
.x
, 1), (self
.y
, -2)])
81 def test_constant(self
):
82 self
.assertEqual(self
.x
.constant
, 0)
83 self
.assertEqual(self
.pi
.constant
, Fraction(22, 7))
84 self
.assertEqual(self
.expr
.constant
, 3)
86 def test_symbols(self
):
87 self
.assertTupleEqual(self
.x
.symbols
, (self
.x
,))
88 self
.assertTupleEqual(self
.pi
.symbols
, ())
89 self
.assertTupleEqual(self
.expr
.symbols
, (self
.x
, self
.y
))
91 def test_dimension(self
):
92 self
.assertEqual(self
.x
.dimension
, 1)
93 self
.assertEqual(self
.pi
.dimension
, 0)
94 self
.assertEqual(self
.expr
.dimension
, 2)
96 def test_isconstant(self
):
97 self
.assertFalse(self
.x
.isconstant())
98 self
.assertTrue(self
.pi
.isconstant())
99 self
.assertFalse(self
.expr
.isconstant())
101 def test_issymbol(self
):
102 self
.assertTrue(self
.x
.issymbol())
103 self
.assertFalse(self
.pi
.issymbol())
104 self
.assertFalse(self
.expr
.issymbol())
106 def test_values(self
):
107 self
.assertListEqual(list(self
.expr
.values()), [1, -2, 3])
110 self
.assertTrue(self
.x
)
111 self
.assertFalse(self
.zero
)
112 self
.assertTrue(self
.pi
)
113 self
.assertTrue(self
.expr
)
116 self
.assertEqual(+self
.expr
, self
.expr
)
119 self
.assertEqual(-self
.expr
, -self
.x
+ 2*self
.y
- 3)
122 self
.assertEqual(self
.x
+ Fraction(22, 7), self
.x
+ self
.pi
)
123 self
.assertEqual(Fraction(22, 7) + self
.x
, self
.x
+ self
.pi
)
124 self
.assertEqual(self
.x
+ self
.x
, 2 * self
.x
)
125 self
.assertEqual(self
.expr
+ 2*self
.y
, self
.x
+ 3)
128 self
.assertEqual(self
.x
- self
.x
, 0)
129 self
.assertEqual(self
.expr
- 3, self
.x
- 2*self
.y
)
130 self
.assertEqual(0 - self
.x
, -self
.x
)
133 self
.assertEqual(self
.pi
* 7, 22)
134 self
.assertEqual(self
.expr
* 0, 0)
135 self
.assertEqual(0 * self
.expr
, 0)
136 self
.assertEqual(self
.expr
* 2, 2*self
.x
- 4*self
.y
+ 6)
137 with self
.assertRaises(TypeError):
140 def test_truediv(self
):
141 with self
.assertRaises(ZeroDivisionError):
143 self
.assertEqual(self
.expr
/ 2, self
.x
/ 2 - self
.y
+ Fraction(3, 2))
144 with self
.assertRaises(TypeError):
148 self
.assertEqual(self
.expr
, self
.expr
)
149 self
.assertNotEqual(self
.x
, self
.y
)
150 self
.assertEqual(self
.zero
, 0)
def test_scaleint(self):
    # scaleint() applied to x + y/2 + z/3 is expected to give 6x + 3y + 2z.
    fractional = self.x + self.y/2 + self.z/3
    scaled = fractional.scaleint()
    self.assertEqual(scaled, 6*self.x + 3*self.y + 2*self.z)
157 self
.assertEqual(self
.x
.subs(self
.x
, 3), 3)
158 self
.assertEqual(self
.x
.subs(self
.x
, self
.x
), self
.x
)
159 self
.assertEqual(self
.x
.subs(self
.x
, self
.y
), self
.y
)
160 self
.assertEqual(self
.x
.subs(self
.x
, self
.x
+ self
.y
), self
.x
+ self
.y
)
161 self
.assertEqual(self
.x
.subs(self
.y
, 3), self
.x
)
162 self
.assertEqual(self
.pi
.subs(self
.x
, 3), self
.pi
)
163 self
.assertEqual(self
.expr
.subs(self
.x
, -3), -2 * self
.y
)
164 self
.assertEqual(self
.expr
.subs([(self
.x
, self
.y
), (self
.y
, self
.x
)]),
165 -2*self
.x
+ self
.y
+ 3)
166 self
.assertEqual(self
.expr
.subs({self
.x
: self
.z
, self
.y
: self
.z
}),
168 self
.assertEqual(self
.expr
.subs({self
.x
: self
.z
, self
.y
: self
.z
}),
170 with self
.assertRaises(TypeError):
172 with self
.assertRaises(TypeError):
173 self
.expr
.subs([('x', self
.z
), ('y', self
.z
)])
174 with self
.assertRaises(TypeError):
175 self
.expr
.subs({'x': self
.z
, 'y': self
.z
})
176 with self
.assertRaises(TypeError):
177 self
.expr
.subs(self
.x
, 'x')
def test_fromstring(self):
    # Each source string must parse to the expression paired with it; the
    # last three spell self.expr with implicit, parenthesised and explicit
    # multiplication.
    cases = (
        ('x', self.x),
        ('-x', -self.x),
        ('22/7', self.pi),
        ('x - 2y + 3', self.expr),
        ('x - (3-1)y + 3', self.expr),
        ('x - 2*y + 3', self.expr),
    )
    for text, expected in cases:
        self.assertEqual(LinExpr.fromstring(text), expected)
188 self
.assertEqual(str(LinExpr()), '0')
189 self
.assertEqual(str(self
.x
), 'x')
190 self
.assertEqual(str(-self
.x
), '-x')
191 self
.assertEqual(str(self
.pi
), '22/7')
192 self
.assertEqual(str(self
.expr
), 'x - 2*y + 3')
195 def test_fromsympy(self
):
197 sp_x
, sp_y
= sympy
.symbols('x y')
198 self
.assertEqual(LinExpr
.fromsympy(sp_x
), self
.x
)
199 self
.assertEqual(LinExpr
.fromsympy(sympy
.Rational(22, 7)), self
.pi
)
200 self
.assertEqual(LinExpr
.fromsympy(sp_x
- 2*sp_y
+ 3), self
.expr
)
201 with self
.assertRaises(TypeError):
202 LinExpr
.fromsympy(sp_x
*sp_y
)
205 def test_tosympy(self
):
207 sp_x
, sp_y
= sympy
.symbols('x y')
208 self
.assertEqual(self
.x
.tosympy(), sp_x
)
209 self
.assertEqual(self
.pi
.tosympy(), sympy
.Rational(22, 7))
210 self
.assertEqual(self
.expr
.tosympy(), sp_x
- 2*sp_y
+ 3)
213 class TestSymbol(unittest
.TestCase
):
220 self
.assertEqual(Symbol('x'), self
.x
)
221 with self
.assertRaises(TypeError):
223 with self
.assertRaises(TypeError):
225 with self
.assertRaises(SyntaxError):
227 with self
.assertRaises(SyntaxError):
229 with self
.assertRaises(SyntaxError):
236 self
.assertEqual(self
.x
.name
, 'x')
238 def test_issymbol(self
):
239 self
.assertTrue(self
.x
.issymbol())
def test_fromstring(self):
    # 'x' must parse to the fixture symbol.
    parsed = Symbol.fromstring('x')
    self.assertEqual(parsed, self.x)
    # '1' is expected to be rejected with SyntaxError.
    self.assertRaises(SyntaxError, Symbol.fromstring, '1')
247 self
.assertEqual(str(self
.x
), 'x')
250 def test_fromsympy(self
):
252 sp_x
= sympy
.Symbol('x')
253 self
.assertEqual(Symbol
.fromsympy(sp_x
), self
.x
)
254 with self
.assertRaises(TypeError):
255 Symbol
.fromsympy(sympy
.Rational(22, 7))
256 with self
.assertRaises(TypeError):
257 Symbol
.fromsympy(2 * sp_x
)
258 with self
.assertRaises(TypeError):
259 Symbol
.fromsympy(sp_x
*sp_x
)
262 class TestDummy(unittest
.TestCase
):
268 self
.assertEqual(self
.x
.name
, 'x')
269 self
.assertTrue(Dummy().name
.startswith('Dummy'))
272 self
.assertEqual(self
.x
, self
.x
)
273 self
.assertNotEqual(self
.x
, Symbol('x'))
274 self
.assertNotEqual(Symbol('x'), self
.x
)
275 self
.assertNotEqual(self
.x
, Dummy('x'))
276 self
.assertNotEqual(Dummy(), Dummy())
279 self
.assertEqual(repr(self
.x
), '_x')
282 self
.assertTrue(repr(dummy1
).startswith('_Dummy_'))
283 self
.assertNotEqual(repr(dummy1
), repr(dummy2
))
286 class TestSymbols(unittest
.TestCase
):
293 self
.assertTupleEqual(symbols('x y'), (self
.x
, self
.y
))
294 self
.assertTupleEqual(symbols('x,y'), (self
.x
, self
.y
))
295 self
.assertTupleEqual(symbols(['x', 'y']), (self
.x
, self
.y
))
296 with self
.assertRaises(TypeError):
298 with self
.assertRaises(TypeError):
302 class TestRational(unittest
.TestCase
):
305 self
.zero
= Rational(0)
306 self
.one
= Rational(1)
307 self
.pi
= Rational(22, 7)
310 self
.assertEqual(Rational(), self
.zero
)
311 self
.assertEqual(Rational(1), self
.one
)
312 self
.assertEqual(Rational(self
.pi
), self
.pi
)
313 self
.assertEqual(Rational('22/7'), self
.pi
)
316 self
.assertEqual(hash(self
.one
), hash(1))
317 self
.assertEqual(hash(self
.pi
), hash(Fraction(22, 7)))
319 def test_isconstant(self
):
320 self
.assertTrue(self
.zero
.isconstant())
323 self
.assertFalse(self
.zero
)
324 self
.assertTrue(self
.pi
)
327 self
.assertEqual(repr(self
.zero
), '0')
328 self
.assertEqual(repr(self
.one
), '1')
329 self
.assertEqual(repr(self
.pi
), '22/7')
332 def test_fromsympy(self
):
334 self
.assertEqual(Rational
.fromsympy(sympy
.Rational(22, 7)), self
.pi
)
335 with self
.assertRaises(TypeError):
336 Rational
.fromsympy(sympy
.Symbol('x'))