From 841943174bb4d3b602e8e055592d8b54d1bb086d Mon Sep 17 00:00:00 2001 From: Vivien Maisonneuve Date: Thu, 22 May 2014 23:59:53 +0200 Subject: [PATCH] Initial commit --- .gitignore | 4 + Makefile | 18 ++ pypol/__init__.py | 15 ++ pypol/isl.py | 267 +++++++++++++++++++++++++ pypol/linear.py | 458 +++++++++++++++++++++++++++++++++++++++++++ setup.py | 10 + tests/__init__.py | 0 tests/test_isl.py | 199 +++++++++++++++++++ tests/test_linear.py | 176 +++++++++++++++++ 9 files changed, 1147 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 pypol/__init__.py create mode 100644 pypol/isl.py create mode 100644 pypol/linear.py create mode 100755 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/test_isl.py create mode 100644 tests/test_linear.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bbe9588 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/build/ +/dist/ +/MANIFEST +__pycache__ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4ab5d2a --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +PYTHON=python3 +RM=rm -rf + +.PHONY: default +default: + @echo "pypol - A polyhedral library based on ISL" + @echo + @echo "Makefile usage:" + @echo " make test run the test suite" + @echo " make clean remove the generated files" + +.PHONY: test +test: + $(PYTHON) -m unittest + +.PHONY: clean +clean: + $(RM) build dist MANIFEST pypol/__pycache__ tests/__pycache__ diff --git a/pypol/__init__.py b/pypol/__init__.py new file mode 100644 index 0000000..fde0347 --- /dev/null +++ b/pypol/__init__.py @@ -0,0 +1,15 @@ + +""" +A polyhedral library based on ISL. +""" + +from .linear import constant, symbol, symbols +from .linear import eq, le, lt, ge, gt +from .linear import empty, universe + + +__all__ = [ + 'constant', 'symbol', 'symbols', + 'eq', 'le', 'lt', 'ge', 'gt', + 'empty', 'universe' +] diff --git a/pypol/isl.py b/pypol/isl.py new file mode 100644 index 0000000..758fb79 --- /dev/null +++ b/pypol/isl.py @@ -0,0 +1,267 @@ + +import ctypes, ctypes.util +import math +import numbers +import operator +import re + +from decimal import Decimal +from fractions import Fraction +from functools import wraps + + +libisl = ctypes.CDLL(ctypes.util.find_library('isl')) + +libisl.isl_printer_get_str.restype = ctypes.c_char_p + + +class Context: + + __slots__ = ('_ic') + + def __init__(self): + self._ic = libisl.isl_ctx_alloc() + + @property + def _as_parameter_(self): + return self._ic + + def __del__(self): + libisl.isl_ctx_free(self) + + def __eq__(self, other): + if not isinstance(other, Context): + return False + return self._ic == other._ic + + +class Value: + + class _ptr(int): + def __new__(cls, iv): + return super().__new__(cls, iv) + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, self) + + _RE_NONFINITE = re.compile( + r'^\s*(?P[-+])?((?PInf(inity)?)|(?PNaN))\s*$', + re.IGNORECASE) + + _RE_FRACTION = re.compile(r'^(?P[-+]?\d+)(/(?P\d+))?$') + + __slots__ = ('context', '_iv', '_numerator', '_denominator') + + def __new__(cls, context, numerator=0, denominator=None): + self = super().__new__(cls) + if not isinstance(context, Context): + raise TypeError('first argument should be a context') + self.context = context + if isinstance(numerator, cls._ptr): + assert denominator is None + self._iv = numerator + if libisl.isl_val_is_rat(self): + # retrieve numerator and denominator as strings to avoid integer + # overflows + ip = libisl.isl_printer_to_str(self.context) + ip = libisl.isl_printer_print_val(ip, self) + string = libisl.isl_printer_get_str(ip).decode() + libisl.isl_printer_free(ip) + m = self._RE_FRACTION.match(string) + assert m is not None + self._numerator = int(m.group('num')) + self._denominator = int(m.group('den')) if m.group('den') else 1 + else: + self._numerator = None + self._denominator = None + return self + if isinstance(numerator, str) and denominator is None: + m = self._RE_NONFINITE.match(numerator) + if m is not None: + self._numerator = None + self._denominator = None + if m.group('inf'): + if m.group('sign') == '-': + self._iv = libisl.isl_val_neginfty(context) + else: + self._iv = libisl.isl_val_infty(context) + else: + assert m.group('nan') + self._iv = libisl.isl_val_nan(context) + return self + try: + frac = Fraction(numerator, denominator) + except ValueError: + raise ValueError('invalid literal for {}: {!r}'.format( + cls.__name__, numerator)) + self._numerator = frac.numerator + self._denominator = frac.denominator + # values passed as strings to avoid integer overflows + if frac.denominator == 1: + numerator = str(frac.numerator).encode() + self._iv = libisl.isl_val_read_from_str(context, numerator) + else: + numerator = str(frac.numerator).encode() + numerator = libisl.isl_val_read_from_str(context, numerator) + denominator = str(frac.denominator).encode() + denominator = libisl.isl_val_read_from_str(context, denominator) + self._iv = libisl.isl_val_div(numerator, denominator) + return self + + @property + def _as_parameter_(self): + return self._iv + + def __del__(self): + libisl.isl_val_free(self) + self.context # prevents context from being GC'ed before the value + + @property + def numerator(self): + if self._numerator is None: + raise ValueError('not a rational number') + return self._numerator + + @property + def denominator(self): + if self._denominator is None: + raise ValueError('not a rational number') + return self._denominator + + def __bool__(self): + return not bool(libisl.isl_val_is_zero(self)) + + def _polymorphic(func): + @wraps(func) + def wrapper(self, other): + if isinstance(other, Value): + return func(self, other) + if isinstance(other, numbers.Rational): + other = Value(self.context, other) + return func(self, other) + raise TypeError('operand should be a Value or a Rational') + return wrapper + + @_polymorphic + def __lt__(self, other): + return bool(libisl.isl_val_lt(self, other)) + + @_polymorphic + def __le__(self, other): + return bool(libisl.isl_val_le(self, other)) + + @_polymorphic + def __gt__(self, other): + return bool(libisl.isl_val_gt(self, other)) + + @_polymorphic + def __ge__(self, other): + return bool(libisl.isl_val_ge(self, other)) + + @_polymorphic + def __eq__(self, other): + return bool(libisl.isl_val_eq(self, other)) + + # __ne__ is not implemented, ISL semantics does not match Python's on + # nan != nan + + def __abs__(self): + val = libisl.isl_val_copy(self) + val = libisl.isl_val_abs(val) + return self.__class__(self.context, self._ptr(val)) + + def __pos__(self): + return self + + def __neg__(self): + val = libisl.isl_val_copy(self) + val = libisl.isl_val_neg(val) + return self.__class__(self.context, self._ptr(val)) + + def __floor__(self): + val = libisl.isl_val_copy(self) + val = libisl.isl_val_floor(val) + return self.__class__(self.context, self._ptr(val)) + + def __ceil__(self): + val = libisl.isl_val_copy(self) + val = libisl.isl_val_ceil(val) + return self.__class__(self.context, self._ptr(val)) + + def __trunc__(self): + val = libisl.isl_val_copy(self) + val = libisl.isl_val_trunc(val) + return self.__class__(self.context, self._ptr(val)) + + @_polymorphic + def __add__(self, other): + val1 = libisl.isl_val_copy(self) + val2 = libisl.isl_val_copy(other) + val = libisl.isl_val_add(val1, val2) + return self.__class__(self.context, self._ptr(val)) + + __radd__ = __add__ + + @_polymorphic + def __sub__(self, other): + val1 = libisl.isl_val_copy(self) + val2 = libisl.isl_val_copy(other) + val = libisl.isl_val_sub(val1, val2) + return self.__class__(self.context, self._ptr(val)) + + __rsub__ = __sub__ + + @_polymorphic + def __mul__(self, other): + val1 = libisl.isl_val_copy(self) + val2 = libisl.isl_val_copy(other) + val = libisl.isl_val_mul(val1, val2) + return self.__class__(self.context, self._ptr(val)) + + __rmul__ = __mul__ + + @_polymorphic + def __truediv__(self, other): + val1 = libisl.isl_val_copy(self) + val2 = libisl.isl_val_copy(other) + val = libisl.isl_val_div(val1, val2) + return self.__class__(self.context, self._ptr(val)) + + __rtruediv__ = __truediv__ + + def __float__(self): + if libisl.isl_val_is_rat(self): + return self.numerator / self.denominator + elif libisl.isl_val_is_infty(self): + return float('inf') + elif libisl.isl_val_is_neginfty(self): + return float('-inf') + else: + assert libisl.isl_val_is_nan(self) + return float('nan') + + def is_finite(self): + return bool(libisl.isl_val_is_rat(self)) + + def is_infinite(self): + return bool(libisl.isl_val_is_infty(self) or + libisl.isl_val_is_neginfty(self)) + + def is_nan(self): + return bool(libisl.isl_val_is_nan(self)) + + def __str__(self): + if libisl.isl_val_is_rat(self): + if self.denominator == 1: + return '{}'.format(self.numerator) + else: + return '{}/{}'.format(self.numerator, self.denominator) + elif libisl.isl_val_is_infty(self): + return 'Infinity' + elif libisl.isl_val_is_neginfty(self): + return '-Infinity' + else: + assert libisl.isl_val_is_nan(self) + return 'NaN' + + def __repr__(self): + return '{}({!r})'.format(self.__class__.__name__, str(self)) diff --git a/pypol/linear.py b/pypol/linear.py new file mode 100644 index 0000000..5b5d8aa --- /dev/null +++ b/pypol/linear.py @@ -0,0 +1,458 @@ + +import functools +import numbers + +from fractions import Fraction, gcd + + +__all__ = [ + 'Expression', + 'constant', 'symbol', 'symbols', + 'eq', 'le', 'lt', 'ge', 'gt', + 'Polyhedron', + 'empty', 'universe' +] + + +class Expression: + """ + This class implements linear expressions. + """ + + def __new__(cls, coefficients=None, constant=0): + if isinstance(coefficients, str): + if constant: + raise TypeError('too many arguments') + return cls.fromstring(coefficients) + self = super().__new__(cls) + self._coefficients = {} + if isinstance(coefficients, dict): + coefficients = coefficients.items() + if coefficients is not None: + for symbol, coefficient in coefficients: + if isinstance(symbol, Expression) and symbol.issymbol(): + symbol = str(symbol) + elif not isinstance(symbol, str): + raise TypeError('symbols must be strings') + if not isinstance(coefficient, numbers.Rational): + raise TypeError('coefficients must be rational numbers') + if coefficient != 0: + self._coefficients[symbol] = coefficient + if not isinstance(constant, numbers.Rational): + raise TypeError('constant must be a rational number') + self._constant = constant + return self + + def symbols(self): + yield from sorted(self._coefficients) + + @property + def dimension(self): + return len(list(self.symbols())) + + def coefficient(self, symbol): + if isinstance(symbol, Expression) and symbol.issymbol(): + symbol = str(symbol) + elif not isinstance(symbol, str): + raise TypeError('symbol must be a string') + try: + return self._coefficients[symbol] + except KeyError: + return 0 + + __getitem__ = coefficient + + def coefficients(self): + for symbol in self.symbols(): + yield symbol, self.coefficient(symbol) + + @property + def constant(self): + return self._constant + + def isconstant(self): + return len(self._coefficients) == 0 + + def values(self): + for symbol in self.symbols(): + yield self.coefficient(symbol) + yield self.constant + + def symbol(self): + if not self.issymbol(): + raise ValueError('not a symbol: {}'.format(self)) + for symbol in self.symbols(): + return symbol + + def issymbol(self): + return len(self._coefficients) == 1 and self._constant == 0 + + def __bool__(self): + return (not self.isconstant()) or bool(self.constant) + + def __pos__(self): + return self + + def __neg__(self): + return self * -1 + + def _polymorphic(func): + @functools.wraps(func) + def wrapper(self, other): + if isinstance(other, Expression): + return func(self, other) + if isinstance(other, numbers.Rational): + other = Expression(constant=other) + return func(self, other) + return NotImplemented + return wrapper + + @_polymorphic + def __add__(self, other): + coefficients = dict(self.coefficients()) + for symbol, coefficient in other.coefficients(): + if symbol in coefficients: + coefficients[symbol] += coefficient + else: + coefficients[symbol] = coefficient + constant = self.constant + other.constant + return Expression(coefficients, constant) + + __radd__ = __add__ + + @_polymorphic + def __sub__(self, other): + coefficients = dict(self.coefficients()) + for symbol, coefficient in other.coefficients(): + if symbol in coefficients: + coefficients[symbol] -= coefficient + else: + coefficients[symbol] = -coefficient + constant = self.constant - other.constant + return Expression(coefficients, constant) + + __rsub__ = __sub__ + + @_polymorphic + def __mul__(self, other): + if other.isconstant(): + coefficients = dict(self.coefficients()) + for symbol in coefficients: + coefficients[symbol] *= other.constant + constant = self.constant * other.constant + return Expression(coefficients, constant) + if isinstance(other, Expression) and not self.isconstant(): + raise ValueError('non-linear expression: ' + '{} * {}'.format(self._parenstr(), other._parenstr())) + return NotImplemented + + __rmul__ = __mul__ + + @_polymorphic + def __truediv__(self, other): + if other.isconstant(): + coefficients = dict(self.coefficients()) + for symbol in coefficients: + coefficients[symbol] = \ + Fraction(coefficients[symbol], other.constant) + constant = Fraction(self.constant, other.constant) + return Expression(coefficients, constant) + if isinstance(other, Expression): + raise ValueError('non-linear expression: ' + '{} / {}'.format(self._parenstr(), other._parenstr())) + return NotImplemented + + def __rtruediv__(self, other): + if isinstance(other, Rational): + if self.isconstant(): + constant = Fraction(other, self.constant) + return Expression(constant=constant) + else: + raise ValueError('non-linear expression: ' + '{} / {}'.format(other._parenstr(), self._parenstr())) + return NotImplemented + + def __str__(self): + string = '' + symbols = sorted(self.symbols()) + i = 0 + for symbol in symbols: + coefficient = self[symbol] + if coefficient == 1: + if i == 0: + string += symbol + else: + string += ' + {}'.format(symbol) + elif coefficient == -1: + if i == 0: + string += '-{}'.format(symbol) + else: + string += ' - {}'.format(symbol) + else: + if i == 0: + string += '{}*{}'.format(coefficient, symbol) + elif coefficient > 0: + string += ' + {}*{}'.format(coefficient, symbol) + else: + assert coefficient < 0 + coefficient *= -1 + string += ' - {}*{}'.format(coefficient, symbol) + i += 1 + constant = self.constant + if constant != 0 and i == 0: + string += '{}'.format(constant) + elif constant > 0: + string += ' + {}'.format(constant) + elif constant < 0: + constant *= -1 + string += ' - {}'.format(constant) + return string + + def _parenstr(self, always=False): + string = str(self) + if not always and (self.isconstant() or self.issymbol()): + return string + else: + return '({})'.format(string) + + def __repr__(self): + string = '{}({{'.format(self.__class__.__name__) + for i, (symbol, coefficient) in enumerate(self.coefficients()): + if i != 0: + string += ', ' + string += '{!r}: {!r}'.format(symbol, coefficient) + string += '}}, {!r})'.format(self.constant) + return string + + @classmethod + def fromstring(cls, string): + raise NotImplementedError + + @_polymorphic + def __eq__(self, other): + # "normal" equality + # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs + return isinstance(other, Expression) and \ + self._coefficients == other._coefficients and \ + self.constant == other.constant + + def __hash__(self): + return hash((self._coefficients, self._constant)) + + def _canonify(self): + lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), + [value.denominator for value in self.values()]) + return self * lcm + + @_polymorphic + def _eq(self, other): + return Polyhedron(equalities=[(self - other)._canonify()]) + + @_polymorphic + def __le__(self, other): + return Polyhedron(inequalities=[(self - other)._canonify()]) + + @_polymorphic + def __lt__(self, other): + return Polyhedron(inequalities=[(self - other)._canonify() + 1]) + + @_polymorphic + def __ge__(self, other): + return Polyhedron(inequalities=[(other - self)._canonify()]) + + @_polymorphic + def __gt__(self, other): + return Polyhedron(inequalities=[(other - self)._canonify() + 1]) + + +def constant(numerator=0, denominator=None): + return Expression(constant=Fraction(numerator, denominator)) + +def symbol(name): + if not isinstance(name, str): + raise TypeError('name must be a string') + return Expression(coefficients={name: 1}) + +def symbols(names): + if isinstance(names, str): + names = names.replace(',', ' ').split() + return (symbol(name) for name in names) + + +def _operator(func): + @functools.wraps(func) + def wrapper(a, b): + if isinstance(a, numbers.Rational): + a = constant(a) + if isinstance(b, numbers.Rational): + b = constant(b) + if isinstance(a, Expression) and isinstance(b, Expression): + return func(a, b) + raise TypeError('arguments must be linear expressions') + return wrapper + +@_operator +def eq(a, b): + return a._eq(b) + +@_operator +def le(a, b): + return a <= b + +@_operator +def lt(a, b): + return a < b + +@_operator +def ge(a, b): + return a >= b + +@_operator +def gt(a, b): + return a > b + + +class Polyhedron: + """ + This class implements polyhedrons. + """ + + def __new__(cls, equalities=None, inequalities=None): + if isinstance(equalities, str): + if inequalities is not None: + raise TypeError('too many arguments') + return cls.fromstring(equalities) + self = super().__new__(cls) + self._equalities = [] + if equalities is not None: + for constraint in equalities: + for value in constraint.values(): + if value.denominator != 1: + raise TypeError('non-integer constraint: ' + '{} == 0'.format(constraint)) + self._equalities.append(constraint) + self._inequalities = [] + if inequalities is not None: + for constraint in inequalities: + for value in constraint.values(): + if value.denominator != 1: + raise TypeError('non-integer constraint: ' + '{} <= 0'.format(constraint)) + self._inequalities.append(constraint) + return self + + @property + def equalities(self): + yield from self._equalities + + @property + def inequalities(self): + yield from self._inequalities + + def constraints(self): + yield from self.equalities + yield from self.inequalities + + def symbols(self): + s = set() + for constraint in self.constraints(): + s.update(constraint.symbols) + yield from sorted(s) + + @property + def dimension(self): + return len(self.symbols()) + + def __bool__(self): + # return false if the polyhedron is empty, true otherwise + raise NotImplementedError + + def __contains__(self, value): + # is the value in the polyhedron? + raise NotImplementedError + + def __eq__(self, other): + raise NotImplementedError + + def isempty(self): + return self == empty + + def isuniverse(self): + return self == universe + + def isdisjoint(self, other): + # return true if the polyhedron has no elements in common with other + raise NotImplementedError + + def issubset(self, other): + raise NotImplementedError + + def __le__(self, other): + return self.issubset(other) + + def __lt__(self, other): + raise NotImplementedError + + def issuperset(self, other): + # test whether every element in other is in the polyhedron + raise NotImplementedError + + def __ge__(self, other): + return self.issuperset(other) + + def __gt__(self, other): + raise NotImplementedError + + def union(self, *others): + # return a new polyhedron with elements from the polyhedron and all + # others (convex union) + raise NotImplementedError + + def __or__(self, other): + return self.union(other) + + def intersection(self, *others): + # return a new polyhedron with elements common to the polyhedron and all + # others + # a poor man's implementation could be: + # equalities = list(self.equalities) + # inequalities = list(self.inequalities) + # for other in others: + # equalities.extend(other.equalities) + # inequalities.extend(other.inequalities) + # return self.__class__(equalities, inequalities) + raise NotImplementedError + + def __and__(self, other): + return self.intersection(other) + + def difference(self, *others): + # return a new polyhedron with elements in the polyhedron that are not + # in the others + raise NotImplementedError + + def __sub__(self, other): + return self.difference(other) + + def __str__(self): + constraints = [] + for constraint in self.equalities: + constraints.append('{} == 0'.format(constraint)) + for constraint in self.inequalities: + constraints.append('{} <= 0'.format(constraint)) + return '{{{}}}'.format(', '.join(constraints)) + + def __repr__(self): + equalities = list(self.equalities) + inequalities = list(self.inequalities) + return '{}(equalities={!r}, inequalities={!r})' \ + ''.format(self.__class__.__name__, equalities, inequalities) + + @classmethod + def fromstring(cls, string): + raise NotImplementedError + + +empty = le(1, 0) + +universe = Polyhedron() diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..a0dfb8a --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +from distutils.core import setup + +setup( + name='pypol', + description='A polyhedral library based on ISL', + author='MINES ParisTech', + packages=['pypol'] +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_isl.py b/tests/test_isl.py new file mode 100644 index 0000000..7186dad --- /dev/null +++ b/tests/test_isl.py @@ -0,0 +1,199 @@ + +import unittest + +from math import floor, ceil, trunc + +from pypol.isl import * + + +class TestContext(unittest.TestCase): + + def test_eq(self): + ctx1, ctx2 = Context(), Context() + self.assertEqual(ctx1, ctx1) + self.assertNotEqual(ctx1, ctx2) + + +class TestValue(unittest.TestCase): + + def setUp(self): + self.context = Context() + self.zero = Value(self.context) + self.nan = Value(self.context, 'NaN') + self.inf = Value(self.context, 'Inf') + self.neginf = Value(self.context, '-Inf') + self.answer = Value(self.context, 42) + self.pi = Value(self.context, 22, 7) + + def test_init(self): + self.assertEqual(Value(self.context, 42), self.answer) + self.assertEqual(Value(self.context, '42'), self.answer) + self.assertEqual(Value(self.context, 22, 7), self.pi) + self.assertEqual(Value(self.context, '-22/7'), -self.pi) + self.assertTrue(Value(self.context, 'nan').is_nan()) + self.assertTrue(Value(self.context, '-nan').is_nan()) + self.assertTrue(Value(self.context, 'NaN').is_nan()) + self.assertEqual(Value(self.context, '-inf'), self.neginf) + self.assertEqual(Value(self.context, '-Infinity'), self.neginf) + + def test_numerator(self): + self.assertEqual(self.zero.numerator, 0) + self.assertEqual(self.answer.numerator, 42) + self.assertEqual(self.pi.numerator, 22) + with self.assertRaises(ValueError): + self.nan.numerator + with self.assertRaises(ValueError): + self.inf.numerator + + def test_denominator(self): + self.assertEqual(self.zero.denominator, 1) + self.assertEqual(self.answer.denominator, 1) + self.assertEqual(self.pi.denominator, 7) + with self.assertRaises(ValueError): + self.nan.denominator + with self.assertRaises(ValueError): + self.inf.denominator + + def test_bool(self): + self.assertFalse(self.zero) + self.assertTrue(self.answer) + self.assertTrue(self.pi) + self.assertEqual(bool(self.nan), bool(float('nan'))) + self.assertEqual(bool(self.inf), bool(float('inf'))) + + def test_lt(self): + self.assertTrue(self.neginf < self.zero) + self.assertTrue(self.zero < self.pi) + self.assertTrue(self.pi < self.answer) + self.assertTrue(self.answer < self.inf) + self.assertFalse(self.nan < self.answer) + self.assertFalse(self.nan < self.inf) + self.assertFalse(self.nan < self.neginf) + self.assertTrue(self.neginf < self.inf) + + def test_le(self): + self.assertTrue(self.pi <= self.pi) + self.assertTrue(self.pi <= self.answer) + self.assertFalse(self.answer <= self.pi) + + def test_gt(self): + self.assertFalse(self.pi > self.pi) + self.assertTrue(self.answer > self.pi) + self.assertFalse(self.pi > self.answer) + + def test_ge(self): + self.assertTrue(self.pi >= self.pi) + self.assertTrue(self.answer >= self.pi) + self.assertFalse(self.pi >= self.answer) + + def test_eq(self): + self.assertEqual(self.pi, self.pi) + self.assertEqual(self.inf, self.inf) + self.assertNotEqual(self.neginf, self.inf) + self.assertNotEqual(self.nan, self.nan) + self.assertEqual(self.zero, 0) + self.assertEqual(0, self.zero) + self.assertEqual(self.pi, Fraction(22, 7)) + self.assertEqual(Fraction(22, 7), self.pi) + with self.assertRaises(TypeError): + self.zero == 0. + + def test_ne(self): + self.assertTrue(self.pi != self.answer) + self.assertFalse(self.pi != self.pi) + self.assertTrue(self.neginf != self.inf) + self.assertTrue(self.nan != self.nan) + + def test_abs(self): + self.assertEqual(abs(self.pi), self.pi) + self.assertEqual(abs(self.neginf), self.inf) + self.assertEqual(abs(-self.pi), self.pi) + self.assertTrue(abs(self.nan).is_nan()) + + def test_pos(self): + self.assertEqual(+self.pi, self.pi) + + def test_neg(self): + self.assertEqual(-self.neginf, self.inf) + self.assertEqual(-(-self.pi), self.pi) + + def test_floor(self): + self.assertEqual(floor(self.pi), Value(self.context, 3)) + self.assertEqual(floor(-self.pi), Value(self.context, -4)) + # not float behavior, but makes sense + self.assertEqual(floor(self.inf), self.inf) + self.assertTrue(floor(self.nan).is_nan()) + + def test_ceil(self): + self.assertEqual(ceil(self.pi), Value(self.context, 4)) + self.assertRaises(ceil(-self.pi) == Value(self.context, -3)) + + def test_trunc(self): + self.assertEqual(trunc(self.pi), Value(self.context, 3)) + self.assertEqual(trunc(-self.pi), Value(self.context, -3)) + + def test_add(self): + self.assertEqual(self.answer + self.answer, Value(self.context, 84)) + self.assertEqual(self.answer + self.pi, Value(self.context, 316, 7)) + self.assertEqual(self.pi + self.pi, Value(self.context, 44, 7)) + self.assertEqual(self.pi + self.neginf, self.neginf) + self.assertEqual(self.pi + self.inf, self.inf) + self.assertTrue((self.pi + self.nan).is_nan()) + self.assertTrue((self.inf + self.nan).is_nan()) + self.assertTrue((self.inf + self.neginf).is_nan()) + self.assertEqual(self.pi + 42, Value(self.context, 316, 7)) + self.assertEqual(42 + self.pi, Value(self.context, 316, 7)) + self.assertEqual(self.pi + Fraction(22, 7), Value(self.context, 44, 7)) + with self.assertRaises(TypeError): + self.pi + float(42) + + def test_sub(self): + self.assertEqual(self.answer - self.pi, Value(self.context, 272, 7)) + + def test_mul(self): + self.assertEqual(Value(self.context, 6) * Value(self.context, 7), self.answer) + self.assertNotEqual(Value(self.context, 6) * Value(self.context, 9), self.answer) + self.assertEqual(self.inf * Value(self.context, 2), self.inf) + self.assertEqual(self.inf * Value(self.context, -2), self.neginf) + self.assertTrue((self.nan * Value(self.context, 2)).is_nan()) + self.assertTrue((self.nan * self.inf).is_nan()) + + def test_div(self): + self.assertEqual(Value(self.context, 22) / Value(self.context, 7), self.pi) + self.assertEqual(self.pi / self.pi, Value(self.context, 1)) + # not float behavior, but makes sense + self.assertTrue((self.pi / self.zero).is_nan()) + + def test_float(self): + self.assertAlmostEqual(float(Value(self.context, 1, 2)), 0.5) + self.assertTrue(math.isnan(float(Value(self.context, 'NaN')))) + self.assertAlmostEqual(float(Value(self.context, 'Inf')), float('inf')) + + def test_is_finite(self): + self.assertTrue(self.pi.is_finite()) + self.assertFalse(self.inf.is_finite()) + self.assertFalse(self.nan.is_finite()) + + def test_is_infinite(self): + self.assertFalse(self.pi.is_infinite()) + self.assertTrue(self.inf.is_infinite()) + self.assertFalse(self.nan.is_infinite()) + + def test_is_nan(self): + self.assertFalse(self.pi.is_nan()) + self.assertFalse(self.inf.is_nan()) + self.assertTrue(self.nan.is_nan()) + + def test_str(self): + self.assertEqual(str(self.answer), '42') + self.assertEqual(str(self.pi), '22/7') + self.assertEqual(str(self.nan), 'NaN') + self.assertEqual(str(self.inf), 'Infinity') + self.assertEqual(str(self.neginf), '-Infinity') + + def test_repr(self): + self.assertEqual(repr(self.answer), "Value('42')") + self.assertEqual(repr(self.pi), "Value('22/7')") + self.assertEqual(repr(self.nan), "Value('NaN')") + self.assertEqual(repr(self.inf), "Value('Infinity')") + self.assertEqual(repr(self.neginf), "Value('-Infinity')") diff --git a/tests/test_linear.py b/tests/test_linear.py new file mode 100644 index 0000000..3912fd3 --- /dev/null +++ b/tests/test_linear.py @@ -0,0 +1,176 @@ + +import unittest + +from fractions import Fraction + +from pypol.linear import * + + +class TestExpression(unittest.TestCase): + + def setUp(self): + self.x = symbol('x') + self.y = symbol('y') + self.z = symbol('z') + self.zero = constant(0) + self.pi = constant(Fraction(22, 7)) + self.e = self.x - 2*self.y + 3 + + def test_new(self): + pass + + def test_symbols(self): + self.assertCountEqual(self.x.symbols(), ['x']) + self.assertCountEqual(self.pi.symbols(), []) + self.assertCountEqual(self.e.symbols(), ['x', 'y']) + + def test_dimension(self): + self.assertEqual(self.x.dimension, 1) + self.assertEqual(self.pi.dimension, 0) + self.assertEqual(self.e.dimension, 2) + + def test_coefficient(self): + self.assertEqual(self.e.coefficient('x'), 1) + self.assertEqual(self.e.coefficient('y'), -2) + self.assertEqual(self.e.coefficient(self.y), -2) + self.assertEqual(self.e.coefficient('z'), 0) + with self.assertRaises(TypeError): + self.e.coefficient(0) + with self.assertRaises(TypeError): + self.e.coefficient(self.e) + + def test_getitem(self): + self.assertEqual(self.e['x'], 1) + self.assertEqual(self.e['y'], -2) + self.assertEqual(self.e[self.y], -2) + self.assertEqual(self.e['z'], 0) + with self.assertRaises(TypeError): + self.e[0] + with self.assertRaises(TypeError): + self.e[self.e] + + def test_coefficients(self): + self.assertCountEqual(self.e.coefficients(), [('x', 1), ('y', -2)]) + + def test_constant(self): + self.assertEqual(self.x.constant, 0) + self.assertEqual(self.pi.constant, Fraction(22, 7)) + self.assertEqual(self.e.constant, 3) + + def test_isconstant(self): + self.assertFalse(self.x.isconstant()) + self.assertTrue(self.pi.isconstant()) + self.assertFalse(self.e.isconstant()) + + def test_values(self): + self.assertCountEqual(self.e.values(), [1, -2, 3]) + + def test_symbol(self): + self.assertEqual(self.x.symbol(), 'x') + with self.assertRaises(ValueError): + self.pi.symbol() + with self.assertRaises(ValueError): + self.e.symbol() + + def test_issymbol(self): + self.assertTrue(self.x.issymbol()) + self.assertFalse(self.pi.issymbol()) + self.assertFalse(self.e.issymbol()) + + def test_bool(self): + self.assertTrue(self.x) + self.assertFalse(self.zero) + self.assertTrue(self.pi) + self.assertTrue(self.e) + + def test_pos(self): + self.assertEqual(+self.e, self.e) + + def test_neg(self): + self.assertEqual(-self.e, -self.x + 2*self.y - 3) + + def test_add(self): + self.assertEqual(self.x + Fraction(22, 7), self.x + self.pi) + self.assertEqual(Fraction(22, 7) + self.x, self.x + self.pi) + self.assertEqual(self.x + self.x, 2 * self.x) + self.assertEqual(self.e + 2*self.y, self.x + 3) + + def test_sub(self): + self.assertEqual(self.x - self.x, 0) + self.assertEqual(self.e - 3, self.x - 2*self.y) + + def test_mul(self): + self.assertEqual(self.pi * 7, 22) + self.assertEqual(self.e * 0, 0) + self.assertEqual(self.e * 2, 2*self.x - 4*self.y + 6) + + def test_div(self): + with self.assertRaises(ZeroDivisionError): + self.e / 0 + self.assertEqual(self.e / 2, self.x / 2 - self.y + Fraction(3, 2)) + + def test_str(self): + self.assertEqual(str(self.x), 'x') + self.assertEqual(str(-self.x), '-x') + self.assertEqual(str(self.pi), '22/7') + self.assertEqual(str(self.e), 'x - 2*y + 3') + + def test_repr(self): + self.assertEqual(repr(self.e), "Expression({'x': 1, 'y': -2}, 3)") + + @unittest.expectedFailure + def test_fromstring(self): + self.assertEqual(Expression.fromstring('x'), self.x) + self.assertEqual(Expression.fromstring('-x'), -self.x) + self.assertEqual(Expression.fromstring('22/7'), self.pi) + self.assertEqual(Expression.fromstring('x - 2y + 3'), self.e) + self.assertEqual(Expression.fromstring('x - (3-1)y + 3'), self.e) + self.assertEqual(Expression.fromstring('x - 2*y + 3'), self.e) + + def test_eq(self): + self.assertEqual(self.e, self.e) + self.assertNotEqual(self.x, self.y) + self.assertEqual(self.zero, 0) + + def test_canonify(self): + self.assertEqual((self.x + self.y/2 + self.z/3)._canonify(), + 6*self.x + 3*self.y + 2*self.z) + + +class TestHelpers(unittest.TestCase): + + def setUp(self): + self.x = symbol('x') + self.y = symbol('y') + + def test_constant(self): + self.assertEqual(constant(3), 3) + self.assertEqual(constant('3'), 3) + self.assertEqual(constant(Fraction(3, 4)), Fraction(3, 4)) + self.assertEqual(constant('3/4'), Fraction(3, 4)) + with self.assertRaises(ValueError): + constant('a') + with self.assertRaises(TypeError): + constant([]) + + def test_symbol(self): + self.assertEqual(symbol('x'), self.x) + self.assertNotEqual(symbol('y'), self.x) + with self.assertRaises(TypeError): + symbol(0) + + def test_symbols(self): + self.assertListEqual(list(symbols('x y')), [self.x, self.y]) + self.assertListEqual(list(symbols('x,y')), [self.x, self.y]) + self.assertListEqual(list(symbols(['x', 'y'])), [self.x, self.y]) + + +class TestOperators(unittest.TestCase): + + pass + + +class TestPolyhedron(unittest.TestCase): + + pass + -- 2.20.1