Commit 28f354bf by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] Add expr_visitor, fix expr_functor exponential blowup problem (#2988)

* save

* lint
parent 5d70b008
......@@ -101,6 +101,7 @@ sequential_pass = ir_pass.sequential_pass
# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprVisitor = expr_functor.ExprVisitor
ExprMutator = expr_functor.ExprMutator
# Parser
......
......@@ -36,9 +36,8 @@ class ExprFunctor:
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if expr in self.memo_map:
return self.memo_map[expr]
if isinstance(expr, Function):
res = self.visit_function(expr)
......@@ -126,6 +125,68 @@ class ExprFunctor:
raise NotImplementedError()
class ExprVisitor(ExprFunctor):
"""
A visitor over Expr.
The default behavior recursively traverses the AST.
"""
def visit_tuple(self, t):
for x in t.fields:
self.visit(x)
def visit_call(self, c):
self.visit(c.op)
for a in c.args:
self.visit(a)
def visit_var(self, v):
pass
def visit_let(self, l):
self.visit(l.var)
self.visit(l.value)
self.visit(l.body)
def visit_function(self, f):
self.visit(f.body)
def visit_if(self, i):
self.visit(i.cond)
self.visit(i.true_branch)
self.visit(i.false_branch)
def visit_global_var(self, gv):
pass
def visit_constructor(self, c):
pass
def visit_op(self, op):
pass
def visit_constant(self, const):
pass
def visit_ref_create(self, r):
self.visit(r.value)
def visit_ref_read(self, r):
self.visit(r.ref)
def visit_ref_write(self, r):
self.visit(r.ref)
self.visit(r.value)
def visit_tuple_getitem(self, t):
self.visit(t.tuple_value)
def visit_match(self, m):
self.visit(m.data)
for c in m.clause:
self.visit(c.rhs)
class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.
......
......@@ -16,34 +16,42 @@
# under the License.
import tvm
from tvm import relay
from tvm.relay import ExprFunctor, ExprMutator
from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor
def check_visit(expr):
ef = ExprFunctor()
try:
ef = ExprFunctor()
ef.visit(expr)
assert False
except NotImplementedError:
pass
ev = ExprVisitor()
ev.visit(expr)
em = ExprMutator()
assert em.visit(expr)
def test_constant():
check_visit(relay.const(1.0))
def test_tuple():
t = relay.Tuple([relay.var('x', shape=())])
check_visit(t)
def test_var():
v = relay.var('x', shape=())
check_visit(v)
def test_global():
v = relay.GlobalVar('f')
check_visit(v)
def test_function():
x = relay.var('x', shape=())
y = relay.var('y', shape=())
......@@ -61,12 +69,14 @@ def test_function():
)
check_visit(f)
def test_call():
x = relay.var('x', shape=())
y = relay.var('y', shape=())
call = relay.op.add(x, y)
check_visit(call)
def test_let():
x = relay.var('x', shape=())
value = relay.const(2.0)
......@@ -74,30 +84,43 @@ def test_let():
l = relay.Let(x, value, body)
check_visit(l)
def test_ite():
cond = relay.var('x', shape=(), dtype='bool')
ite = relay.If(cond, cond, cond)
check_visit(ite)
def test_get_item():
t = relay.Tuple([relay.var('x', shape=())])
t = relay.TupleGetItem(t, 0)
check_visit(t)
def test_ref_create():
r = relay.expr.RefCreate(relay.const(1.0))
check_visit(r)
def test_ref_read():
ref = relay.expr.RefCreate(relay.const(1.0))
r = relay.expr.RefRead(ref)
check_visit(r)
def test_ref_write():
ref = relay.expr.RefCreate(relay.const(1.0))
r = relay.expr.RefWrite(ref, relay.const(2.0))
check_visit(r)
def test_memo():
expr = relay.const(1)
for _ in range(100):
expr = expr + expr
check_visit(expr)
if __name__ == "__main__":
test_constant()
test_tuple()
......@@ -110,3 +133,4 @@ if __name__ == "__main__":
test_ref_create()
test_ref_read()
test_ref_write()
test_memo()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment