Unverified Commit 3616ebee by masahi Committed by GitHub

[BYOC] Add example of Composite + Annotate for DNNL fused op (#5272)

* merge change from dev branch

* fix string issue

* bring comanic's change back
parent 4b27cd14
...@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True): ...@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d") _register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense") _register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu") _register_external_op_helper("nn.relu")
_register_external_op_helper("add") _register_external_op_helper("add")
_register_external_op_helper("subtract") _register_external_op_helper("subtract")
_register_external_op_helper("multiply") _register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
"""
return False
...@@ -19,19 +19,22 @@ ...@@ -19,19 +19,22 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "../../utils.h"
#include "codegen_c.h" #include "codegen_c.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace contrib { namespace contrib {
using namespace backend;
/*! /*!
* \brief An example codegen that is only used for quick prototyping and testing * \brief An example codegen that is only used for quick prototyping and testing
* purpose. Only several binary options are covered. Users * purpose. Only several binary options are covered. Users
......
...@@ -170,41 +170,6 @@ class CodegenCBase { ...@@ -170,41 +170,6 @@ class CodegenCBase {
virtual std::string JIT() = 0; virtual std::string JIT() = 0;
/*! /*!
* \brief Extract the shape from a Relay tensor type.
*
* \param type The provided type.
*
* \return The extracted shape in a list.
*/
std::vector<int> GetShape(const Type& type) const {
const auto* ttype = type.as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
return shape;
}
/*!
* \brief Check if a call has the provided name.
*
* \param call A Relay call node.
* \param op_name The name of the expected call.
*
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
bool IsOp(const CallNode* call, const std::string& op_name) const {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
* \brief A common interface that is used by various external runtime to * \brief A common interface that is used by various external runtime to
* generate the wrapper to invoke external kernels. * generate the wrapper to invoke external kernels.
* *
......
...@@ -25,18 +25,19 @@ ...@@ -25,18 +25,19 @@
#define TVM_RELAY_BACKEND_UTILS_H_ #define TVM_RELAY_BACKEND_UTILS_H_
#include <dmlc/json.h> #include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h> #include <tvm/relay/type.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/te/operation.h> #include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
#include <typeinfo>
#include <string> #include <string>
#include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) { ...@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
*/ */
template <typename R, typename... Args> template <typename R, typename... Args>
inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) { inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
auto *pf = GetPackedFunc(func_name); auto* pf = GetPackedFunc(func_name);
CHECK(pf != nullptr) << "can not find packed function"; CHECK(pf != nullptr) << "can not find packed function";
return runtime::TypedPackedFunc<R(Args...)>(*pf); return runtime::TypedPackedFunc<R(Args...)>(*pf);
} }
...@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) { ...@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
* \param params params dict * \param params params dict
* \return relay::Function * \return relay::Function
*/ */
inline relay::Function inline relay::Function BindParamsByName(
BindParamsByName(relay::Function func, relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict; std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var; std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) { for (auto arg : func->params) {
...@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func, ...@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func,
return ret; return ret;
} }
/*!
* \brief Extract the shape from a Relay tensor type.
* \param type The provided type.
* \return The extracted shape in a list.
*/
inline std::vector<int> GetShape(const Type& type) {
const auto* ttype = type.as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
return shape;
}
/*!
* \brief Check if a call has the provided name.
* \param call A Relay call node.
* \param op_name The name of the expected call.
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
inline bool IsOp(const CallNode* call, const std::string& op_name) {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
* \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
* \param call A Relay call node. Typically nn.relu when called the first time.
* \param depth The number of calls before the root op, counting from current_call.
* \param expected_op_names The names of ops in this fused call. Example: {"nn.conv2d", "add",
* "nn.relu"}
* \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
*/
inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
const std::vector<std::string>& expected_op_names) {
CHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
IsOp(current_call, expected_op_names[depth]));
if (depth == 0) {
return current_call;
}
CHECK_GT(current_call->args.size(), 0);
const auto* next_call = current_call->args[0].as<CallNode>();
return GetRootCall(next_call, depth - 1, expected_op_names);
}
} // namespace backend } // namespace backend
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_BACKEND_UTILS_H_ #endif // TVM_RELAY_BACKEND_UTILS_H_
...@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this // Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation // pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op` // pass. This is because memory allocation pass will insert `invoke_tvm_op`
...@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen. // external codegen.
pass_seqs.push_back(transform::Inline()); pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions. // Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
...@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { ...@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle)); std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle));
} }
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) {
int p_Sh_, int p_Sw_) {
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
engine eng(engine::kind::cpu, 0); engine eng(engine::kind::cpu, 0);
...@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_}; memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_};
if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_}; if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_};
memory::dims conv2d_bias_tz = {p_O_}; memory::dims conv2d_bias_tz = {p_O_};
memory::dims conv2d_dst_tz = {p_N_, p_O_, memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_}; (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_};
memory::dims conv2d_strides = {p_Sh_, p_Sw_}; memory::dims conv2d_strides = {p_Sh_, p_Sw_};
memory::dims conv2d_padding = {p_Ph_, p_Pw_}; memory::dims conv2d_padding = {p_Ph_, p_Pw_};
std::vector<float> conv2d_bias(p_O_, 0); auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory =
auto user_src_memory = memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights);
memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
auto user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng,
weights);
auto conv2d_user_bias_memory =
memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data());
auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
...@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw); auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);
auto conv2d_desc = convolution_forward::desc( auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct, prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md, conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding);
conv2d_strides, conv2d_padding, conv2d_padding); auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng);
auto conv2d_src_memory = user_src_memory; auto conv2d_src_memory = user_src_memory;
auto conv2d_weights_memory = user_weights_memory; auto conv2d_weights_memory = user_weights_memory;
...@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
read_from_dnnl_memory(out, conv2d_dst_memory); read_from_dnnl_memory(out, conv2d_dst_memory);
} }
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Kw_, int p_Sh_, int p_Sw_) {
primitive_attr attr;
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr);
}
primitive_attr create_attr_with_relu_post_op() {
post_ops ops;
ops.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f);
primitive_attr attr;
attr.set_post_ops(ops);
return attr;
}
extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) {
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op());
}
extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
int p_N_, int p_C_, int p_H_, int p_W_, int p_O_,
int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph_,
p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op());
}
extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
int p_I_, int p_O_) { int p_I_, int p_O_) {
using tag = memory::format_tag; using tag = memory::format_tag;
...@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, ...@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory); read_from_dnnl_memory(out, dst_memory);
} }
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
float* variance, float* out, int p_N_, int p_C_, float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) { int p_H_, int p_W_, int p_E_) {
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
......
...@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int ...@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_);
extern "C" TVM_DLL void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_);
extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias,
float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_,
int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_,
int p_Sw_);
extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
int p_O_); int p_O_);
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_, float* variance, float* out, float* new_mean, float* new_variance,
int p_e_); int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_); int p_h_, int p_w_);
......
...@@ -27,11 +27,10 @@ from tvm import relay ...@@ -27,11 +27,10 @@ from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay import transform
from tvm.relay.backend import compile_engine from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container from tvm.relay.build_module import bind_params_by_name
# Leverage the pass manager to write a simple white list based annotator # Leverage the pass manager to write a simple white list based annotator
...@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet(): ...@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload( mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32') batch_size=1, dtype='float32')
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) mod["main"] = bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget(["dnnl"])(mod) mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod) mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod) mod = transform.PartitionGraph()(mod)
...@@ -663,7 +662,7 @@ def test_constant_propagation(): ...@@ -663,7 +662,7 @@ def test_constant_propagation():
add = x + y add = x + y
log = relay.log(add) log = relay.log(add)
f = relay.Function([x, y], log) f = relay.Function([x, y], log)
f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)}) f = bind_params_by_name(f, {"x": tvm.nd.array(ones)})
mod = tvm.IRModule() mod = tvm.IRModule()
mod["main"] = f mod["main"] = f
mod = WhiteListAnnotator(["add"], "ccompiler")(mod) mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
...@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs(): ...@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs():
partitioned = transform.PartitionGraph()(mod) partitioned = transform.PartitionGraph()(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_dnnl_fuse():
def make_pattern(with_bias=True):
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight")
bias = relay.var("bias")
conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
channels=8, padding=(1, 1))
if with_bias:
conv_out = relay.add(conv, bias)
else:
conv_out = conv
return relay.nn.relu(conv_out)
conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
def get_blocks(prefix, data, in_channel, out_channel,
include_bn=True, include_sigmoid=False):
weight = relay.var(prefix + "weight")
bn_gamma = relay.var(prefix + "bn_gamma")
bn_beta = relay.var(prefix + "bn_beta")
bn_mmean = relay.var(prefix + "bn_mean")
bn_mvar = relay.var(prefix + "bn_var")
layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
channels=out_channel, padding=(1, 1))
if include_bn:
bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta,
bn_mmean, bn_mvar)
layer = bn_output[0]
if include_sigmoid:
# dummy layer to prevent pattern detection
layer = relay.sigmoid(layer)
layer = relay.nn.relu(layer)
return layer
def get_net(include_bn=True, include_sigmoid=False):
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
block1 = get_blocks("block1_", data, 3, 8, include_bn, include_sigmoid)
# The second block is always conv + relu, to make it more interesting
block2 = get_blocks("block2_", block1, 8, 8, False, include_sigmoid)
return relay.Function(relay.analysis.free_vars(block2), block2)
def get_partitoned_mod(mod, params, pattern_table):
# This is required for constant folding
mod["main"] = bind_params_by_name(mod["main"], params)
remove_bn_pass = transform.Sequential([
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
])
composite_partition = transform.Sequential([
remove_bn_pass,
transform.MergeComposite(pattern_table),
transform.AnnotateTarget("dnnl"),
transform.PartitionGraph()
])
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
return composite_partition(mod)
def test_detect_pattern(pattern_table, include_bn, include_sigmoid,
num_expected_partition):
net = get_net(include_bn, include_sigmoid)
mod, params = tvm.relay.testing.create_workload(net)
mod = get_partitoned_mod(mod, params, pattern_table)
assert(len(mod.functions) - 1 == num_expected_partition) # -1 for main
def test_partition():
# conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
test_detect_pattern([conv2d_bias_relu_pat], True, False, 3)
# conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
test_detect_pattern([conv2d_relu_pat], True, False, 4)
# conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2)
# conv + relu, conv + relu -> two fused conv_relu
test_detect_pattern([conv2d_relu_pat], False, False, 2)
# conv + relu, conv + relu -> no fusion, 4 partition each with a single op
test_detect_pattern([conv2d_bias_relu_pat], False, False, 4)
# conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5)
def test_partition_mobilenet():
mod, params = relay.testing.mobilenet.get_workload()
mod = get_partitoned_mod(mod, params, dnnl_patterns)
# 27 fused conv + bn + relu and one dense
assert(len(mod.functions) - 1 == 28) # -1 for main
def test_exec(mod, params, ref_mod, ref_params, out_shape):
ishape = (1, 3, 224, 224)
i_data = np.random.randn(*ishape).astype(np.float32)
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
ref_res = ref_ex.evaluate()(i_data, **ref_params)
compile_engine.get().clear()
mod = get_partitoned_mod(mod, params, dnnl_patterns)
check_result(mod, {"data": i_data},
out_shape, ref_res.asnumpy(), tol=1e-5, params=params)
test_partition()
test_partition_mobilenet()
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return
net = get_net()
mod, params = tvm.relay.testing.create_workload(net)
ref_mod, ref_params = tvm.relay.testing.create_workload(net)
test_exec(mod, params, ref_mod, ref_params, (1, 8, 224, 224))
# exec test on mobilenet is not possible due to manually inlined constants
# mod, params = relay.testing.mobilenet.get_workload()
# ref_mod, ref_params = relay.testing.mobilenet.get_workload()
# test_exec(mod, params, ref_mod, ref_params, (1, 1000))
if __name__ == "__main__": if __name__ == "__main__":
test_multi_node_compiler() test_multi_node_compiler()
test_extern_ccompiler_single_op() test_extern_ccompiler_single_op()
...@@ -865,3 +986,4 @@ if __name__ == "__main__": ...@@ -865,3 +986,4 @@ if __name__ == "__main__":
test_constant_propagation() test_constant_propagation()
test_multiple_outputs() test_multiple_outputs()
test_mixed_single_multiple_outputs() test_mixed_single_multiple_outputs()
test_dnnl_fuse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment