+
import functools
import numbers

from fractions import Fraction
from math import gcd
+
+
# Public API of this module, in declaration order.
__all__ = [
    'Expression',
    'constant',
    'symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
+
+
class Expression:
    """
    This class implements linear expressions.

    An expression is a set of (symbol, coefficient) terms, where symbols are
    strings and coefficients are non-zero rational numbers, plus a rational
    constant term.  Arithmetic operators return new expressions; <, <=, >,
    >= (and the private _eq) return Polyhedron objects, while == and !=
    compare expressions structurally.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients is either a string (delegated to fromstring()), a dict,
        or an iterable of (symbol, coefficient) pairs.  A symbol is a string
        or a single-symbol Expression; terms whose coefficient is zero are
        dropped.  constant is the rational constant term.

        Raises TypeError for a non-string symbol, a non-rational coefficient
        or constant, or a non-zero constant passed along with a string.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the symbols that have a non-zero coefficient, sorted."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not appear.

        symbol is a string or a single-symbol Expression; anything else
        raises TypeError.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs, sorted by symbol."""
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if no symbol has a non-zero coefficient."""
        return not self._coefficients

    def values(self):
        """Yield the coefficients in symbol order, then the constant."""
        for _, coefficient in self.coefficients():
            yield coefficient
        yield self.constant

    def symbol(self):
        """
        Return the symbol name of a single-symbol expression such as ``x``.

        Raises ValueError if the expression is not a symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """Return True if the expression is exactly one bare symbol."""
        # A scaled symbol such as 2*x is not a symbol, so the single
        # coefficient must also be 1.
        return (len(self._coefficients) == 1 and self._constant == 0 and
                next(iter(self._coefficients.values())) == 1)

    def __bool__(self):
        # Only the constant expression 0 is false.
        return bool(self._coefficients) or bool(self._constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    def _polymorphic(func):
        # Decorator for binary operators: a Rational operand is promoted to a
        # constant Expression, and any other non-Expression operand yields
        # NotImplemented so Python can try the reflected operation.
        @functools.wraps(func)
        def wrapper(self, other):
            if isinstance(other, numbers.Rational):
                other = Expression(constant=other)
            elif not isinstance(other, Expression):
                return NotImplemented
            return func(self, other)
        return wrapper

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        return Expression(coefficients, self.constant + other.constant)

    # Addition is commutative, so the reflected form can reuse __add__.
    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        return Expression(coefficients, self.constant - other.constant)

    @_polymorphic
    def __rsub__(self, other):
        # Reflected subtraction must compute other - self (e.g. 5 - x);
        # subtraction is not commutative, so it cannot alias __sub__.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply two expressions.

        The product is linear only if at least one factor is constant;
        otherwise ValueError is raised.
        """
        if other.isconstant():
            expression, factor = self, other.constant
        elif self.isconstant():
            expression, factor = other, self.constant
        else:
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        coefficients = {symbol: coefficient * factor
                for symbol, coefficient in expression.coefficients()}
        return Expression(coefficients, expression.constant * factor)

    # Multiplication is commutative, so the reflected form can reuse __mul__.
    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant expression, using exact Fraction arithmetic.

        Raises ValueError if the divisor is not constant; a zero divisor
        raises ZeroDivisionError from Fraction.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
                for symbol, coefficient in self.coefficients()}
        return Expression(coefficients, Fraction(self.constant, divisor))

    @_polymorphic
    def __rtruediv__(self, other):
        # Reflected division computes other / self (e.g. 1 / Expression(2));
        # __truediv__ reports non-linear cases in the correct operand order.
        return other / self

    def __str__(self):
        """Return a readable form such as '2*x - y + 3'."""
        if self.isconstant():
            # Also covers the zero expression, which prints as '0'.
            return '{}'.format(self.constant)
        parts = []
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i == 0:
                # The leading term carries its own sign.
                if coefficient == 1:
                    parts.append(symbol)
                elif coefficient == -1:
                    parts.append('-{}'.format(symbol))
                else:
                    parts.append('{}*{}'.format(coefficient, symbol))
            else:
                sign = ' + ' if coefficient > 0 else ' - '
                magnitude = abs(coefficient)
                if magnitude == 1:
                    parts.append(sign + symbol)
                else:
                    parts.append('{}{}*{}'.format(sign, magnitude, symbol))
        if self.constant > 0:
            parts.append(' + {}'.format(self.constant))
        elif self.constant < 0:
            parts.append(' - {}'.format(-self.constant))
        return ''.join(parts)

    def _parenstr(self, always=False):
        # str(self), parenthesized unless it is a constant or a bare symbol
        # (or always=True forces parentheses).
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        terms = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, terms,
                self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return (self._coefficients == other._coefficients and
                self.constant == other.constant)

    def __hash__(self):
        # The coefficient dict itself is unhashable, so hash the sorted
        # (symbol, coefficient) pairs.  A constant expression compares equal
        # to its number (see __eq__), so it must hash like that number.
        if self.isconstant():
            return hash(self._constant)
        return hash((tuple(self.coefficients()), self._constant))

    def _canonify(self):
        # Scale by the least common multiple of all denominators so every
        # coefficient and the constant become integers.
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                (value.denominator for value in self.values()))
        return self * lcm

    # The constraint builders below produce integer-valued constraints of
    # the form e == 0 or e <= 0.  A strict e < 0 is encoded as e + 1 <= 0,
    # which presumes the symbols range over the integers.

    @_polymorphic
    def _eq(self, other):
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic
    def __le__(self, other):
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic
    def __lt__(self, other):
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic
    def __ge__(self, other):
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic
    def __gt__(self, other):
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
+
+
def constant(numerator=0, denominator=None):
    """Return a constant Expression whose value is Fraction(numerator, denominator)."""
    value = Fraction(numerator, denominator)
    return Expression(constant=value)
+
def symbol(name):
    """
    Return the Expression consisting of the single symbol ``name``.

    Raises TypeError if name is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
+
def symbols(names):
    """
    Return a generator of symbol Expressions, one per name.

    names is either an iterable of strings or a single string in which the
    names are separated by commas and/or whitespace.  Each symbol is built
    lazily, only when the generator is consumed.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return (symbol(each) for each in names)
+
+
def _operator(func):
    """
    Decorate a binary function so that it only receives Expressions.

    Rational arguments are promoted with constant(); any remaining
    non-Expression argument raises TypeError.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        a = constant(a) if isinstance(a, numbers.Rational) else a
        b = constant(b) if isinstance(b, numbers.Rational) else b
        if not (isinstance(a, Expression) and isinstance(b, Expression)):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
+
@_operator
def eq(a, b):
    """
    Return the polyhedron of points satisfying ``a == b``.

    Both operands have already been promoted to Expression by _operator.
    """
    polyhedron = a._eq(b)
    return polyhedron
+
@_operator
def le(a, b):
    """
    Return the polyhedron of points satisfying ``a <= b``.

    Written as the mirrored ``b >= a``; Expression.__ge__ builds the same
    inequality, (a - b) <= 0, that ``a <= b`` would.
    """
    return b >= a
+
@_operator
def lt(a, b):
    """
    Return the polyhedron of points satisfying ``a < b``.

    Written as the mirrored ``b > a``; Expression.__gt__ builds the same
    inequality, (a - b) + 1 <= 0, that ``a < b`` would.
    """
    return b > a
+
@_operator
def ge(a, b):
    """
    Return the polyhedron of points satisfying ``a >= b``.

    Written as the mirrored ``b <= a``; Expression.__le__ builds the same
    inequality, (b - a) <= 0, that ``a >= b`` would.
    """
    return b <= a
+
@_operator
def gt(a, b):
    """
    Return the polyhedron of points satisfying ``a > b``.

    Written as the mirrored ``b < a``; Expression.__lt__ builds the same
    inequality, (b - a) + 1 <= 0, that ``a > b`` would.
    """
    return b < a
+
+
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as a list of equalities, each read as
    ``constraint == 0``, and a list of inequalities, each read as
    ``constraint <= 0``.  Every constraint is an Expression whose
    coefficients and constant all have denominator 1.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of constraint expressions.

        If equalities is a string, the polyhedron is built by fromstring()
        and inequalities must not be given.

        Raises TypeError if a constraint is not an Expression or has a
        non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checkconstraints(equalities, '{} == 0')
        self._inequalities = cls._checkconstraints(inequalities, '{} <= 0')
        return self

    @staticmethod
    def _checkconstraints(constraints, template):
        # Validate the constraints and return them as a new list; template
        # renders an offending constraint in the error message.
        checked = []
        if constraints is None:
            return checked
        for constraint in constraints:
            if not isinstance(constraint, Expression):
                raise TypeError('constraints must be linear expressions')
            if any(value.denominator != 1 for value in constraint.values()):
                raise TypeError('non-integer constraint: ' +
                        template.format(constraint))
            checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """Iterate over the equality constraints (each read as ``== 0``)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterate over the inequality constraints (each read as ``<= 0``)."""
        yield from self._inequalities

    def constraints(self):
        """Iterate over all equalities, then all inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield every symbol used by any constraint, sorted, without duplicates."""
        found = set()
        for constraint in self.constraints():
            found.update(constraint.symbols())
        yield from sorted(found)

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        # symbols() is a generator, so it has to be materialized to count it.
        return len(list(self.symbols()))

    def __bool__(self):
        """Return False if the polyhedron is empty, True otherwise (not implemented yet)."""
        raise NotImplementedError

    def __contains__(self, value):
        """Return whether value lies in the polyhedron (not implemented yet)."""
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isempty(self):
        """Return whether this polyhedron equals the module-level ``empty`` (needs __eq__)."""
        return self == empty

    def isuniverse(self):
        """Return whether this polyhedron equals the module-level ``universe`` (needs __eq__)."""
        return self == universe

    def isdisjoint(self, other):
        """Return True if the polyhedron shares no element with other (not implemented yet)."""
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        """Return whether every element of other is in the polyhedron (not implemented yet)."""
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        """
        Return the convex union of the polyhedron and all others
        (not implemented yet).
        """
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return a polyhedron of the elements common to this polyhedron and all
        others.

        The result simply conjoins every constraint.  No simplification or
        redundancy elimination is performed.  Raises TypeError if an
        argument is not a Polyhedron.
        """
        equalities = list(self._equalities)
        inequalities = list(self._inequalities)
        for other in others:
            if not isinstance(other, Polyhedron):
                raise TypeError('arguments must be polyhedrons')
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        """
        Return the elements of the polyhedron that are in none of the others
        (not implemented yet).
        """
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(c) for c in self.equalities]
        constraints += ['{} <= 0'.format(c) for c in self.inequalities]
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        return '{}(equalities={!r}, inequalities={!r})'.format(
                self.__class__.__name__, list(self.equalities),
                list(self.inequalities))

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError
+
+
# The empty polyhedron: the single constraint 1 <= 0 has no solution.
empty = ge(0, 1)
+
# The universe polyhedron: no constraints at all, so every point belongs to it.
universe = Polyhedron(equalities=[], inequalities=[])