# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy. If not, see <http://www.gnu.org/licenses/>.
21 from fractions
import Fraction
23 from ..linexprs
import *
24 from .libhelper
import requires_sympy
27 class TestLinExpr(unittest
.TestCase
):
33 self
.zero
= LinExpr(constant
=0)
34 self
.one
= LinExpr(constant
=1)
35 self
.pi
= LinExpr(constant
=Fraction(22, 7))
36 self
.expr
= self
.x
- 2*self
.y
+ 3
39 self
.assertIsInstance(LinExpr(coefficients
={self
.x
: 1}), Symbol
)
40 self
.assertIsInstance(LinExpr(constant
=self
.pi
), Rational
)
41 self
.assertNotIsInstance(self
.x
+ self
.pi
, Symbol
)
42 self
.assertNotIsInstance(self
.x
+ self
.pi
, Rational
)
43 xx
= LinExpr({self
.x
: 2})
44 self
.assertNotIsInstance(xx
, Symbol
)
45 with self
.assertRaises(TypeError):
47 with self
.assertRaises(TypeError):
49 with self
.assertRaises(TypeError):
51 self
.assertEqual(LinExpr(constant
=1), LinExpr(constant
=self
.one
))
52 self
.assertEqual(LinExpr(constant
='1'), LinExpr(constant
=self
.one
))
53 with self
.assertRaises(ValueError):
def test_coefficient(self):
    """coefficient() returns a symbol's coefficient, 0 when absent, and rejects non-symbols."""
    self.assertEqual(self.expr.coefficient(self.x), 1)
    self.assertEqual(self.expr.coefficient(self.y), -2)
    # z does not occur in x - 2*y + 3, so its coefficient is 0.
    self.assertEqual(self.expr.coefficient(self.z), 0)
    # Every non-Symbol argument must be rejected by coefficient() itself.
    # Note the singular name: coefficients() is called without arguments in
    # test_coefficients, so passing 'x' to it would presumably raise TypeError
    # on arity and hide whether the argument type is actually checked.
    for invalid in ('x', 0, self.expr):
        with self.assertRaises(TypeError):
            self.expr.coefficient(invalid)
67 def test_getitem(self
):
68 self
.assertEqual(self
.expr
[self
.x
], 1)
69 self
.assertEqual(self
.expr
[self
.y
], -2)
70 self
.assertEqual(self
.expr
[self
.z
], 0)
71 with self
.assertRaises(TypeError):
72 self
.assertEqual(self
.expr
['x'], 1)
73 with self
.assertRaises(TypeError):
75 with self
.assertRaises(TypeError):
def test_coefficients(self):
    """coefficients() yields (symbol, coefficient) pairs for x - 2*y + 3."""
    expected_pairs = [(self.x, 1), (self.y, -2)]
    actual_pairs = [pair for pair in self.expr.coefficients()]
    self.assertListEqual(actual_pairs, expected_pairs)
def test_constant(self):
    """The constant term of a symbol, a rational and a mixed expression."""
    expectations = (
        (self.x, 0),
        (self.pi, Fraction(22, 7)),
        (self.expr, 3),
    )
    for expression, constant_term in expectations:
        self.assertEqual(expression.constant, constant_term)
def test_symbols(self):
    """symbols is a tuple of the symbols an expression involves."""
    expectations = (
        (self.x, (self.x,)),
        (self.pi, ()),
        (self.expr, (self.x, self.y)),
    )
    for expression, involved in expectations:
        self.assertTupleEqual(expression.symbols, involved)
def test_dimension(self):
    """dimension is 1 for a symbol, 0 for a constant and 2 for x - 2*y + 3."""
    expectations = ((self.x, 1), (self.pi, 0), (self.expr, 2))
    for expression, size in expectations:
        self.assertEqual(expression.dimension, size)
def test_isconstant(self):
    """Only the purely rational expression reports itself as constant."""
    expectations = ((self.x, False), (self.pi, True), (self.expr, False))
    for expression, constant in expectations:
        if constant:
            self.assertTrue(expression.isconstant())
        else:
            self.assertFalse(expression.isconstant())
def test_issymbol(self):
    """Only a bare symbol reports itself as a symbol."""
    expectations = ((self.x, True), (self.pi, False), (self.expr, False))
    for expression, is_symbol in expectations:
        if is_symbol:
            self.assertTrue(expression.issymbol())
        else:
            self.assertFalse(expression.issymbol())
def test_values(self):
    """values() yields the coefficients followed by the constant term."""
    collected = [value for value in self.expr.values()]
    self.assertListEqual(collected, [1, -2, 3])
110 self
.assertTrue(self
.x
)
111 self
.assertFalse(self
.zero
)
112 self
.assertTrue(self
.pi
)
113 self
.assertTrue(self
.expr
)
116 self
.assertEqual(+self
.expr
, self
.expr
)
119 self
.assertEqual(-self
.expr
, -self
.x
+ 2*self
.y
- 3)
122 self
.assertEqual(self
.x
+ Fraction(22, 7), self
.x
+ self
.pi
)
123 self
.assertEqual(Fraction(22, 7) + self
.x
, self
.x
+ self
.pi
)
124 self
.assertEqual(self
.x
+ self
.x
, 2 * self
.x
)
125 self
.assertEqual(self
.expr
+ 2*self
.y
, self
.x
+ 3)
128 self
.assertEqual(self
.x
- self
.x
, 0)
129 self
.assertEqual(self
.expr
- 3, self
.x
- 2*self
.y
)
130 self
.assertEqual(0 - self
.x
, -self
.x
)
133 self
.assertEqual(self
.pi
* 7, 22)
134 self
.assertEqual(self
.expr
* 0, 0)
135 self
.assertEqual(0 * self
.expr
, 0)
136 self
.assertEqual(self
.expr
* 2, 2*self
.x
- 4*self
.y
+ 6)
137 with self
.assertRaises(TypeError):
140 def test_truediv(self
):
141 with self
.assertRaises(ZeroDivisionError):
143 self
.assertEqual(self
.expr
/ 2, self
.x
/ 2 - self
.y
+ Fraction(3, 2))
144 with self
.assertRaises(TypeError):
148 self
.assertEqual(self
.expr
, self
.expr
)
149 self
.assertNotEqual(self
.x
, self
.y
)
150 self
.assertEqual(self
.zero
, 0)
def test_scaleint(self):
    """scaleint() turns x + y/2 + z/3 into integer coefficients (scaled by 6)."""
    scaled = (self.x + self.y/2 + self.z/3).scaleint()
    integral = 6*self.x + 3*self.y + 2*self.z
    self.assertEqual(scaled, integral)
157 self
.assertEqual(self
.x
.subs(self
.x
, 3), 3)
158 self
.assertEqual(self
.x
.subs(self
.x
, self
.x
), self
.x
)
159 self
.assertEqual(self
.x
.subs(self
.x
, self
.y
), self
.y
)
160 self
.assertEqual(self
.x
.subs(self
.x
, self
.x
+ self
.y
), self
.x
+ self
.y
)
161 self
.assertEqual(self
.x
.subs(self
.y
, 3), self
.x
)
162 self
.assertEqual(self
.pi
.subs(self
.x
, 3), self
.pi
)
163 self
.assertEqual(self
.expr
.subs(self
.x
, -3), -2 * self
.y
)
164 self
.assertEqual(self
.expr
.subs([(self
.x
, self
.y
), (self
.y
, self
.x
)]), 3 - self
.x
)
165 self
.assertEqual(self
.expr
.subs({self
.x
: self
.z
, self
.y
: self
.z
}), 3 - self
.z
)
166 self
.assertEqual(self
.expr
.subs({self
.x
: self
.z
, self
.y
: self
.z
}), 3 - self
.z
)
167 with self
.assertRaises(TypeError):
169 with self
.assertRaises(TypeError):
170 self
.expr
.subs([('x', self
.z
), ('y', self
.z
)])
171 with self
.assertRaises(TypeError):
172 self
.expr
.subs({'x': self
.z
, 'y': self
.z
})
173 with self
.assertRaises(TypeError):
174 self
.expr
.subs(self
.x
, 'x')
def test_fromstring(self):
    """LinExpr.fromstring parses symbols, rationals and implicit/explicit products."""
    samples = [
        ('x', self.x),
        ('-x', -self.x),
        ('22/7', self.pi),
        # Implicit multiplication, a parenthesised factor and an explicit '*'
        # must all produce the same expression.
        ('x - 2y + 3', self.expr),
        ('x - (3-1)y + 3', self.expr),
        ('x - 2*y + 3', self.expr),
    ]
    for source_text, expected in samples:
        self.assertEqual(LinExpr.fromstring(source_text), expected)
185 self
.assertEqual(str(LinExpr()), '0')
186 self
.assertEqual(str(self
.x
), 'x')
187 self
.assertEqual(str(-self
.x
), '-x')
188 self
.assertEqual(str(self
.pi
), '22/7')
189 self
.assertEqual(str(self
.expr
), 'x - 2*y + 3')
192 def test_fromsympy(self
):
194 sp_x
, sp_y
= sympy
.symbols('x y')
195 self
.assertEqual(LinExpr
.fromsympy(sp_x
), self
.x
)
196 self
.assertEqual(LinExpr
.fromsympy(sympy
.Rational(22, 7)), self
.pi
)
197 self
.assertEqual(LinExpr
.fromsympy(sp_x
- 2*sp_y
+ 3), self
.expr
)
198 with self
.assertRaises(ValueError):
199 LinExpr
.fromsympy(sp_x
*sp_y
)
202 def test_tosympy(self
):
204 sp_x
, sp_y
= sympy
.symbols('x y')
205 self
.assertEqual(self
.x
.tosympy(), sp_x
)
206 self
.assertEqual(self
.pi
.tosympy(), sympy
.Rational(22, 7))
207 self
.assertEqual(self
.expr
.tosympy(), sp_x
- 2*sp_y
+ 3)
210 class TestSymbol(unittest
.TestCase
):
217 self
.assertEqual(Symbol(' x '), self
.x
)
218 with self
.assertRaises(TypeError):
220 with self
.assertRaises(TypeError):
224 self
.assertEqual(self
.x
.name
, 'x')
def test_issymbol(self):
    """A Symbol instance reports itself as a symbol."""
    answer = self.x.issymbol()
    self.assertTrue(answer)
def test_fromstring(self):
    """Symbol.fromstring accepts a name and rejects a numeric literal."""
    parsed = Symbol.fromstring('x')
    self.assertEqual(parsed, self.x)
    # '1' is not a valid symbol name, so parsing it must fail.
    self.assertRaises(SyntaxError, Symbol.fromstring, '1')
235 self
.assertEqual(str(self
.x
), 'x')
238 def test_fromsympy(self
):
240 sp_x
= sympy
.Symbol('x')
241 self
.assertEqual(Symbol
.fromsympy(sp_x
), self
.x
)
242 with self
.assertRaises(TypeError):
243 Symbol
.fromsympy(sympy
.Rational(22, 7))
244 with self
.assertRaises(TypeError):
245 Symbol
.fromsympy(2 * sp_x
)
246 with self
.assertRaises(TypeError):
247 Symbol
.fromsympy(sp_x
*sp_x
)
250 class TestDummy(unittest
.TestCase
):
256 self
.assertEqual(self
.x
.name
, 'x')
257 self
.assertTrue(Dummy().name
.startswith('Dummy'))
260 self
.assertEqual(self
.x
, self
.x
)
261 self
.assertNotEqual(self
.x
, Symbol('x'))
262 self
.assertNotEqual(Symbol('x'), self
.x
)
263 self
.assertNotEqual(self
.x
, Dummy('x'))
264 self
.assertNotEqual(Dummy(), Dummy())
267 self
.assertEqual(repr(self
.x
), '_x')
270 self
.assertTrue(repr(dummy1
).startswith('_Dummy_'))
271 self
.assertNotEqual(repr(dummy1
), repr(dummy2
))
274 class TestSymbols(unittest
.TestCase
):
281 self
.assertTupleEqual(symbols('x y'), (self
.x
, self
.y
))
282 self
.assertTupleEqual(symbols('x,y'), (self
.x
, self
.y
))
283 self
.assertTupleEqual(symbols(['x', 'y']), (self
.x
, self
.y
))
284 with self
.assertRaises(TypeError):
286 with self
.assertRaises(TypeError):
290 class TestRational(unittest
.TestCase
):
293 self
.zero
= Rational(0)
294 self
.one
= Rational(1)
295 self
.pi
= Rational(22, 7)
298 self
.assertEqual(Rational(), self
.zero
)
299 self
.assertEqual(Rational(1), self
.one
)
300 self
.assertEqual(Rational(self
.pi
), self
.pi
)
301 self
.assertEqual(Rational('22/7'), self
.pi
)
304 self
.assertEqual(hash(self
.one
), hash(1))
305 self
.assertEqual(hash(self
.pi
), hash(Fraction(22, 7)))
def test_isconstant(self):
    """A Rational, even zero, reports itself as constant."""
    constant = self.zero.isconstant()
    self.assertTrue(constant)
311 self
.assertFalse(self
.zero
)
312 self
.assertTrue(self
.pi
)
315 self
.assertEqual(repr(self
.zero
), '0')
316 self
.assertEqual(repr(self
.one
), '1')
317 self
.assertEqual(repr(self
.pi
), '22/7')
320 def test_fromsympy(self
):
322 self
.assertEqual(Rational
.fromsympy(sympy
.Rational(22, 7)), self
.pi
)
323 with self
.assertRaises(TypeError):
324 Rational
.fromsympy(sympy
.Symbol('x'))