Unverified Commit c9a2f3da by Tianqi Chen Committed by GitHub

[RELAY] Pass infra cleanup (#3336)

parent d6c4aba8
......@@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
v->Visit("required", &required);
}
TVM_DLL static PassInfo make(int opt_level, std::string name,
TVM_DLL static PassInfo make(int opt_level,
std::string name,
tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo";
......
......@@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_API("relay._transform.CreateModulePass")
.set_body_typed(CreateModulePass);
TVM_REGISTER_API("relay._transform.MakeModulePass")
.set_body_typed(ModulePassNode::make);
TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._transform.CreateFunctionPass")
.set_body_typed(CreateFunctionPass);
TVM_REGISTER_API("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
......
......@@ -259,6 +259,12 @@ def test_function_pass():
test_pass_run()
def test_pass_info():
info = relay.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
assert info.name == "xyz"
def test_sequential_pass():
shape = (10, )
dtype = 'float32'
......@@ -449,3 +455,4 @@ if __name__ == "__main__":
test_function_pass()
test_sequential_pass()
test_sequential_with_scoping()
test_pass_info()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment