Commit fc4ba796 by Haichen Shen

add expr simplify and canonical

parent 77345051
"""Base class of symbolic expression""" """Base class of symbolic expression"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from numbers import Number as _Number from numbers import Number as _Number
from . import op as _op
from . import var_name as _name from . import var_name as _name
__addop__ = None
__subop__ = None
__mulop__ = None
__divop__ = None
class Expr(object): class Expr(object):
"""Base class of expression. """Base class of expression.
...@@ -20,28 +24,28 @@ class Expr(object): ...@@ -20,28 +24,28 @@ class Expr(object):
return () return ()
def __add__(self, other): def __add__(self, other):
return BinaryOpExpr(_op.add, self, other) return BinaryOpExpr(__addop__, self, other)
def __radd__(self, other): def __radd__(self, other):
return self.__add__(other) return self.__add__(other)
def __sub__(self, other): def __sub__(self, other):
return BinaryOpExpr(_op.sub, self, other) return BinaryOpExpr(__subop__, self, other)
def __rsub__(self, other): def __rsub__(self, other):
return BinaryOpExpr(_op.sub, other, self) return BinaryOpExpr(__subop__, other, self)
def __mul__(self, other): def __mul__(self, other):
return BinaryOpExpr(_op.mul, self, other) return BinaryOpExpr(__mulop__, self, other)
def __rmul__(self, other): def __rmul__(self, other):
return BinaryOpExpr(_op.mul, other, self) return BinaryOpExpr(__mulop__, other, self)
def __div__(self, other): def __div__(self, other):
return BinaryOpExpr(_op.div, self, other) return BinaryOpExpr(__divop__, self, other)
def __rdiv__(self, other): def __rdiv__(self, other):
return BinaryOpExpr(_op.div, other, self) return BinaryOpExpr(__divop__, other, self)
def __truediv__(self, other): def __truediv__(self, other):
return self.__div__(other) return self.__div__(other)
...@@ -75,7 +79,8 @@ class Var(Expr): ...@@ -75,7 +79,8 @@ class Var(Expr):
optional name to the var. optional name to the var.
""" """
def __init__(self, name=None): def __init__(self, name=None):
self.name = name if name else _name.NameManager.current.get(name) if name is None: name = 'i'
self.name = _name.NameManager.current.get(name)
class ConstExpr(Expr): class ConstExpr(Expr):
...@@ -95,7 +100,6 @@ class BinaryOpExpr(Expr): ...@@ -95,7 +100,6 @@ class BinaryOpExpr(Expr):
def children(self): def children(self):
return (self.lhs, self.rhs) return (self.lhs, self.rhs)
_op.binary_op_cls = BinaryOpExpr
class UnaryOpExpr(Expr): class UnaryOpExpr(Expr):
"""Unary operator expression.""" """Unary operator expression."""
......
"""Utilities to manipulate expression""" """Utilities to manipulate expression"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import expr as _expr from . import expr as _expr
from . import op as _op
def expr_with_new_children(e, children): def expr_with_new_children(e, children):
"""Returns same expr as e but with new children """Returns same expr as e but with new children
...@@ -48,6 +49,7 @@ def transform(e, f): ...@@ -48,6 +49,7 @@ def transform(e, f):
result : return value of f result : return value of f
The final result of transformation. The final result of transformation.
""" """
assert isinstance(e, _expr.Expr)
return f(e , [transform(c, f) for c in e.children()]) return f(e , [transform(c, f) for c in e.children()])
...@@ -77,6 +79,32 @@ def format_str(expr): ...@@ -77,6 +79,32 @@ def format_str(expr):
raise TypeError("Do not know how to handle type " + str(type(e))) raise TypeError("Do not know how to handle type " + str(type(e)))
return transform(expr, make_str) return transform(expr, make_str)
def simplify(expr):
"""simplify expression
Parameters
----------
expr : Expr
Input expression
Returns
-------
e : Expr
Simplified expression
"""
def canonical(e, result_children):
if isinstance(e, _expr.BinaryOpExpr):
return e.op.canonical(result_children[0], result_children[1])
elif isinstance(e, _expr.UnaryOpExpr):
return e.op.canonical(result_children[0])
elif isinstance(e, _expr.ConstExpr):
return {_op.constant_canonical_key: e.value}
elif isinstance(e, _expr.Var):
return {e: 1}
else:
raise TypeError("Do not know how to handle type " + str(type(e)))
return _op.canonical_to_expr(transform(expr, canonical))
def bind(expr, update_dict): def bind(expr, update_dict):
"""Replace the variable in e by specification from kwarg """Replace the variable in e by specification from kwarg
......
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import expr as _expr
_binary_op_cls = None constant_canonical_key = '__constant__'
def canonical_to_expr(c):
elements = []
for k, v in sorted(c.items()):
if k == constant_canonical_key:
elements.append(_expr.const(v))
elif v == 0:
continue
elif v == 1:
elements.append(k)
else:
elements.append(k * v)
if elements:
expr = elements[0]
for i in range(1, len(elements)):
expr = expr + elements[i]
return expr
else:
return _expr.const(0)
class BinaryOp(object): class BinaryOp(object):
"""Base class of binary operator""" """Base class of binary operator"""
def __call__(self, lhs, rhs): def __call__(self, lhs, rhs):
return _binary_op_cls(self, lhs, rhs) return _expr.BinaryOpExpr(self, lhs, rhs)
class AddOp(BinaryOp): class AddOp(BinaryOp):
def format_str(self, lhs, rhs): def format_str(self, lhs, rhs):
return '(%s + %s)' % (lhs, rhs) return '(%s + %s)' % (lhs, rhs)
def canonical(self, lhs, rhs):
lhs = lhs.copy()
for k, v in rhs.items():
if k in lhs:
lhs[k] += v
else:
lhs[k] = v
return lhs
class SubOp(BinaryOp): class SubOp(BinaryOp):
def format_str(self, lhs, rhs): def format_str(self, lhs, rhs):
return '(%s - %s)' % (lhs, rhs) return '(%s - %s)' % (lhs, rhs)
def canonical(self, lhs, rhs):
lhs = lhs.copy()
for k, v in rhs.items():
if k in lhs:
lhs[k] -= v
else:
lhs[k] = -v
return lhs
class MulOp(BinaryOp): class MulOp(BinaryOp):
def format_str(self, lhs, rhs): def format_str(self, lhs, rhs):
return '(%s * %s)' % (lhs, rhs) return '(%s * %s)' % (lhs, rhs)
def canonical(self, lhs, rhs):
elhs = canonical_to_expr(lhs)
erhs = canonical_to_expr(rhs)
if isinstance(erhs, _expr.ConstExpr):
lhs = lhs.copy()
for k, v in lhs.items():
lhs[k] *= erhs.value
return lhs
if isinstance(elhs, _expr.ConstExpr):
rhs = rhs.copy()
for k, v in rhs.items():
rhs[k] *= elhs.value
return rhs
return {elhs * erhs: 1}
class DivOp(BinaryOp): class DivOp(BinaryOp):
def format_str(self, lhs, rhs): def format_str(self, lhs, rhs):
return '(%s / %s)' % (lhs, rhs) return '(%s / %s)' % (lhs, rhs)
def canonical(self, lhs, rhs):
erhs = canonical_to_expr(rhs)
if isinstance(erhs, _expr.ConstExpr):
lhs = lhs.copy()
for k, v in lhs.items():
lhs[k] /= erhs.value
return lhs
elhs = canonical_to_expr(lhs)
return {elhs / erhs: 1}
class MaxOp(BinaryOp):
def format_str(self, lhs, rhs):
return 'max(%s, %s)' % (lhs, rhs)
def canonical(self, lhs, rhs):
diff = SubOp().canonical(lhs, rhs)
ediff = canonical_to_expr(diff)
if isinstance(ediff, _expr.ConstExpr):
return lhs if ediff.value >= 0 else rhs
return {MaxOp()(lhs, rhs): 1}
class MinOp(BinaryOp):
def format_str(self, lhs, rhs):
return 'min(%s, %s)' % (lhs, rhs)
def canonical(self, lhs, rhs):
diff = SubOp().canonical(lhs, rhs)
ediff = canonical_to_expr(diff)
if isinstance(ediff, _expr.ConstExpr):
return rhs if ediff.value >= 0 else lhs
return {MinOp()(lhs, rhs): 1}
add = AddOp() add = AddOp()
sub = SubOp() sub = SubOp()
mul = MulOp() mul = MulOp()
div = DivOp() div = DivOp()
max = MaxOp()
min = MinOp()
_expr.__addop__ = add
_expr.__subop__ = sub
_expr.__mulop__ = mul
_expr.__divop__ = div
...@@ -9,12 +9,24 @@ def test_bind(): ...@@ -9,12 +9,24 @@ def test_bind():
def test_basic(): def test_basic():
a= tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
c = a + b c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name) assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
def test_simplify():
a = tvm.Var('a')
b = tvm.Var('b')
e1 = a * (2 + 1) + b * 1
e2 = a * (2 + 1) - b * 1
e3 = tvm.max(a * 3.3 + 5, 3 + 3.3 * a)
e4 = a - a
assert tvm.format_str(tvm.simplify(e1)) == '((%s * 3) + %s)' % (a.name, b.name)
assert tvm.format_str(tvm.simplify(e2)) == '((%s * 3) + (%s * -1))' % (a.name, b.name)
assert tvm.format_str(tvm.simplify(e3)) == '((%s * 3.3) + 5)' % (a.name)
assert tvm.format_str(tvm.simplify(e4)) == '0'
if __name__ == "__main__": if __name__ == "__main__":
test_simplify()
test_basic() test_basic()
test_bind() test_bind()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment