Commit 28f354bf by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] Add expr_visitor, fix expr_functor exponential blowup problem (#2988)

* save

* lint
parent 5d70b008
...@@ -101,6 +101,7 @@ sequential_pass = ir_pass.sequential_pass ...@@ -101,6 +101,7 @@ sequential_pass = ir_pass.sequential_pass
# ExprFunctor # ExprFunctor
ExprFunctor = expr_functor.ExprFunctor ExprFunctor = expr_functor.ExprFunctor
ExprVisitor = expr_functor.ExprVisitor
ExprMutator = expr_functor.ExprMutator ExprMutator = expr_functor.ExprMutator
# Parser # Parser
......
...@@ -36,9 +36,8 @@ class ExprFunctor: ...@@ -36,9 +36,8 @@ class ExprFunctor:
# pylint: disable=no-else-return # pylint: disable=no-else-return
def visit(self, expr): def visit(self, expr):
"""Apply the visitor to an expression.""" """Apply the visitor to an expression."""
found = self.memo_map.get(expr) if expr in self.memo_map:
if found: return self.memo_map[expr]
return found
if isinstance(expr, Function): if isinstance(expr, Function):
res = self.visit_function(expr) res = self.visit_function(expr)
...@@ -126,6 +125,68 @@ class ExprFunctor: ...@@ -126,6 +125,68 @@ class ExprFunctor:
raise NotImplementedError() raise NotImplementedError()
class ExprVisitor(ExprFunctor):
"""
A visitor over Expr.
The default behavior recursively traverses the AST.
"""
def visit_tuple(self, t):
for x in t.fields:
self.visit(x)
def visit_call(self, c):
self.visit(c.op)
for a in c.args:
self.visit(a)
def visit_var(self, v):
pass
def visit_let(self, l):
self.visit(l.var)
self.visit(l.value)
self.visit(l.body)
def visit_function(self, f):
self.visit(f.body)
def visit_if(self, i):
self.visit(i.cond)
self.visit(i.true_branch)
self.visit(i.false_branch)
def visit_global_var(self, gv):
pass
def visit_constructor(self, c):
pass
def visit_op(self, op):
pass
def visit_constant(self, const):
pass
def visit_ref_create(self, r):
self.visit(r.value)
def visit_ref_read(self, r):
self.visit(r.ref)
def visit_ref_write(self, r):
self.visit(r.ref)
self.visit(r.value)
def visit_tuple_getitem(self, t):
self.visit(t.tuple_value)
def visit_match(self, m):
self.visit(m.data)
for c in m.clause:
self.visit(c.rhs)
class ExprMutator(ExprFunctor): class ExprMutator(ExprFunctor):
""" """
A functional visitor over Expr. A functional visitor over Expr.
......
...@@ -16,34 +16,42 @@ ...@@ -16,34 +16,42 @@
# under the License. # under the License.
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import ExprFunctor, ExprMutator from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor
def check_visit(expr): def check_visit(expr):
ef = ExprFunctor()
try: try:
ef = ExprFunctor()
ef.visit(expr) ef.visit(expr)
assert False assert False
except NotImplementedError: except NotImplementedError:
pass pass
ev = ExprVisitor()
ev.visit(expr)
em = ExprMutator() em = ExprMutator()
assert em.visit(expr) assert em.visit(expr)
def test_constant(): def test_constant():
check_visit(relay.const(1.0)) check_visit(relay.const(1.0))
def test_tuple(): def test_tuple():
t = relay.Tuple([relay.var('x', shape=())]) t = relay.Tuple([relay.var('x', shape=())])
check_visit(t) check_visit(t)
def test_var(): def test_var():
v = relay.var('x', shape=()) v = relay.var('x', shape=())
check_visit(v) check_visit(v)
def test_global(): def test_global():
v = relay.GlobalVar('f') v = relay.GlobalVar('f')
check_visit(v) check_visit(v)
def test_function(): def test_function():
x = relay.var('x', shape=()) x = relay.var('x', shape=())
y = relay.var('y', shape=()) y = relay.var('y', shape=())
...@@ -61,12 +69,14 @@ def test_function(): ...@@ -61,12 +69,14 @@ def test_function():
) )
check_visit(f) check_visit(f)
def test_call(): def test_call():
x = relay.var('x', shape=()) x = relay.var('x', shape=())
y = relay.var('y', shape=()) y = relay.var('y', shape=())
call = relay.op.add(x, y) call = relay.op.add(x, y)
check_visit(call) check_visit(call)
def test_let(): def test_let():
x = relay.var('x', shape=()) x = relay.var('x', shape=())
value = relay.const(2.0) value = relay.const(2.0)
...@@ -74,30 +84,43 @@ def test_let(): ...@@ -74,30 +84,43 @@ def test_let():
l = relay.Let(x, value, body) l = relay.Let(x, value, body)
check_visit(l) check_visit(l)
def test_ite(): def test_ite():
cond = relay.var('x', shape=(), dtype='bool') cond = relay.var('x', shape=(), dtype='bool')
ite = relay.If(cond, cond, cond) ite = relay.If(cond, cond, cond)
check_visit(ite) check_visit(ite)
def test_get_item(): def test_get_item():
t = relay.Tuple([relay.var('x', shape=())]) t = relay.Tuple([relay.var('x', shape=())])
t = relay.TupleGetItem(t, 0) t = relay.TupleGetItem(t, 0)
check_visit(t) check_visit(t)
def test_ref_create(): def test_ref_create():
r = relay.expr.RefCreate(relay.const(1.0)) r = relay.expr.RefCreate(relay.const(1.0))
check_visit(r) check_visit(r)
def test_ref_read(): def test_ref_read():
ref = relay.expr.RefCreate(relay.const(1.0)) ref = relay.expr.RefCreate(relay.const(1.0))
r = relay.expr.RefRead(ref) r = relay.expr.RefRead(ref)
check_visit(r) check_visit(r)
def test_ref_write(): def test_ref_write():
ref = relay.expr.RefCreate(relay.const(1.0)) ref = relay.expr.RefCreate(relay.const(1.0))
r = relay.expr.RefWrite(ref, relay.const(2.0)) r = relay.expr.RefWrite(ref, relay.const(2.0))
check_visit(r) check_visit(r)
def test_memo():
expr = relay.const(1)
for _ in range(100):
expr = expr + expr
check_visit(expr)
if __name__ == "__main__": if __name__ == "__main__":
test_constant() test_constant()
test_tuple() test_tuple()
...@@ -110,3 +133,4 @@ if __name__ == "__main__": ...@@ -110,3 +133,4 @@ if __name__ == "__main__":
test_ref_create() test_ref_create()
test_ref_read() test_ref_read()
test_ref_write() test_ref_write()
test_memo()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment