Unverified Commit 327891cb by anwang2009 Committed by GitHub

[Relay][Pass] Add submodule extraction pass (#4960)

* rebased

* fix lint
parent 2e913f0b
......@@ -407,3 +407,25 @@ def structural_hash(value):
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
def extract_fused_functions(mod):
"""Pass to extract IRModule of only fused primitive functions.
The ExtractFusedFunctions pass invokes SimplifyInference, FuseOps(3),
and ExtractFusedFunctions in that order
Parameters
----------
mod : tvm.relay.IRModule
Returns
-------
ret : Dict[int, tvm.relay.expr.Function]
A module containing only fused primitive functions
"""
ret_mod = _analysis.ExtractFusedFunctions()(mod)
ret = {}
for hash_, func in ret_mod.functions.items():
ret[hash_] = func
return ret
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file extract_fused_functions.cc
* \brief Apply fusion and extract fused primitive functions from an IRModule
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
class FusedFunctionExtractorWrapper : private ExprVisitor {
public:
explicit FusedFunctionExtractorWrapper(const IRModule& mod) : mod_(mod) {}
IRModule Extract() {
VisitExpr(this->mod_->Lookup("main"));
auto functions = Map<GlobalVar, BaseFunc>();
for (auto pair : this->functions) {
functions.Set(GlobalVar(pair.first), pair.second);
}
this->mod_->functions = functions;
return this->mod_;
}
private:
const IRModule mod_;
// This is not simply Map<GlobalVar, Function> because GlobalVar doesn't
// have the desired equals property
Map<std::string, Function> functions;
void VisitExpr_(const FunctionNode* n) final {
if (n->HasNonzeroAttr(attr::kPrimitive)) {
// Add function to functions, keyed by function hash string
Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
size_t hash_ = StructuralHash()(func);
this->functions.Set(std::to_string(hash_), func);
}
ExprVisitor::VisitExpr_(n);
}
};
namespace transform {
Pass ExtractFusedFunctions() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return FusedFunctionExtractorWrapper(m).Extract(); };
auto fused_function_extractor_pass = CreateModulePass(pass_func, 1, "ExtractFusedFunctions", {});
return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass},
"ExtractFusedFunctions");
}
TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);
} // namespace transform
} // namespace relay
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test function extraction"""
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload
def get_conv_net():
"""This gets the net for a case described in fuse_ops.cc:
conv2d
/ | \
/ | \
op op op
\ | /
\ | /
elemwise add
|
"""
dshape = (1, 1, 5, 1)
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)
x1 = relay.nn.conv2d(y, relay.var("w2"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)
x2 = relay.nn.conv2d(y, relay.var("w3"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)
x3 = relay.nn.conv2d(y, relay.var("w4"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)
z = relay.add(x1, x2)
z = relay.add(x3, z)
return tvm.IRModule.from_expr(z)
def get_conv2d():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NHWC',
kernel_layout='HWIO')
return tvm.IRModule.from_expr(y)
def test_extract_identity():
mod = get_conv2d()
items = relay.analysis.extract_fused_functions(mod)
assert len(items) == 1
mod["main"] = mod["main"].with_attr(
"Primitive", tvm.tir.IntImm("int32", 1))
relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])
def test_extract_conv_net():
mod = get_conv_net()
items = relay.analysis.extract_fused_functions(mod)
functions = list(items.values())
assert len(functions) == 2
x = functions[0]
y = functions[1]
def is_conv(func):
conv2d = relay.op.op.get("nn.conv2d")
call_node = func.body
return call_node.op == conv2d
def is_conv_add(func):
add = relay.op.op.get("add")
call_node = func.body
maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0])
return call_node.op == add and is_conv(maybe_conv_module["main"])
# Function traversal order isn't obvious, so checking both orders is more consistent
assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y))
def test_extract_resnet():
mod, _params = get_workload()
items = relay.analysis.extract_fused_functions(mod)
assert len(items) == 34
if __name__ == '__main__':
test_extract_identity()
test_extract_conv_net()
test_extract_resnet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment