Commit c8259e3e by 雾雨魔理沙 Committed by Haichen Shen

[Relay] fix checkwellform (#2705)

* do

* address comment
parent cab5af26
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <unordered_set> #include <unordered_set>
namespace tvm { namespace tvm {
...@@ -12,7 +13,7 @@ namespace relay { ...@@ -12,7 +13,7 @@ namespace relay {
//! brief make sure each Var is bind at most once. //! brief make sure each Var is bind at most once.
class WellFormedChecker : private ExprVisitor { class WellFormedChecker : private ExprVisitor, PatternVisitor {
bool well_formed = true; bool well_formed = true;
std::unordered_set<Var, NodeHash, NodeEqual> s; std::unordered_set<Var, NodeHash, NodeEqual> s;
...@@ -39,6 +40,14 @@ class WellFormedChecker : private ExprVisitor { ...@@ -39,6 +40,14 @@ class WellFormedChecker : private ExprVisitor {
CheckWellFormed(f->body); CheckWellFormed(f->body);
} }
void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}
void VisitVar(const Var& v) final {
Check(v);
}
public: public:
bool CheckWellFormed(const Expr& e) { bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e); this->VisitExpr(e);
......
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import well_formed from tvm.relay.ir_pass import well_formed
from tvm.relay.prelude import Prelude
def test_well_formed(): def test_let():
x = relay.Var('x') x = relay.Var("x")
assert well_formed(x) assert well_formed(x)
v = relay.Constant(tvm.nd.array(10)) v = relay.Constant(tvm.nd.array(10))
ty = None ty = None
...@@ -18,7 +19,7 @@ def test_well_formed(): ...@@ -18,7 +19,7 @@ def test_well_formed():
def test_tuple(): def test_tuple():
x = relay.Var('x') x = relay.Var("x")
assert well_formed(x) assert well_formed(x)
v = relay.Constant(tvm.nd.array(10)) v = relay.Constant(tvm.nd.array(10))
let = relay.Let(x, v, x) let = relay.Let(x, v, x)
...@@ -28,5 +29,23 @@ def test_tuple(): ...@@ -28,5 +29,23 @@ def test_tuple():
def test_tuple_get_item(): def test_tuple_get_item():
t = relay.Var('t') t = relay.Var("t")
assert well_formed(relay.TupleGetItem(t, 2)) assert well_formed(relay.TupleGetItem(t, 2))
def test_adt():
mod = relay.Module()
p = Prelude(mod)
x = relay.Var("x")
s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x)
default_case = relay.Clause(relay.PatternVar(x), x)
m0 = relay.Match(p.z(), [default_case])
m1 = relay.Match(p.z(), [s_case, default_case])
assert well_formed(m0)
assert not well_formed(m1)
if __name__ == "__main__":
test_let()
test_tuple()
test_tuple_get_item()
test_adt()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment