Unverified Commit 3616ebee by masahi Committed by GitHub

[BYOC] Add example of Composite + Annotate for DNNL fused op (#5272)

* merge change from dev branch

* fix string issue

* bring comaniac's change back
parent 4b27cd14
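For readers new to BYOC, the flow this example exercises is roughly: describe each fusable operator sequence as a Relay pattern, merge matching subgraphs into Composite functions, annotate supported ops and composites for the dnnl target, and partition the annotated regions into external functions. A minimal sketch of that flow, mirroring the new test added at the bottom of this diff (mod stands for any Relay IRModule; make_pattern is the pattern builder defined in that test):

from tvm.relay import transform

# Each entry pairs a composite name with a Relay expression describing the pattern.
pattern_table = [
    ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True)),   # conv2d -> add -> relu
    ("dnnl.conv2d_relu", make_pattern(with_bias=False)),       # conv2d -> relu
]

seq = transform.Sequential([
    transform.MergeComposite(pattern_table),  # wrap matched subgraphs in Composite functions
    transform.AnnotateTarget("dnnl"),         # mark supported ops/composites for the DNNL codegen
    transform.PartitionGraph(),               # lift annotated regions into external functions
])
mod = seq(mod)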
......@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turned off due to lack of support for multiple outputs.
"""
return False
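For context, _register_external_op_helper (top of this hunk) registers a predicate that reports whether an op is supported on the dnnl target; the explicit batch_norm registration above replaces the helper-based one so the op can be turned off. A plausible shape of the helper, reconstructed from its signature and the reg.register form used above (a sketch, not copied from the file):

def _register_external_op_helper(op_name, supported=True):
    @reg.register(op_name, "target.dnnl")
    def _func_wrapper(attrs, args):
        # Report a fixed supported/unsupported decision for this op.
        return supported
    return _func_wrapper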
......@@ -19,19 +19,22 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>
#include <fstream>
#include <sstream>
#include "../../utils.h"
#include "codegen_c.h"
namespace tvm {
namespace relay {
namespace contrib {
using namespace backend;
/*!
* \brief An example codegen that is only used for quick prototyping and testing
* purposes. Only a few binary operators are covered. Users
......
......@@ -170,41 +170,6 @@ class CodegenCBase {
virtual std::string JIT() = 0;
/*!
* \brief Extract the shape from a Relay tensor type.
*
* \param type The provided type.
*
* \return The extracted shape in a list.
*/
std::vector<int> GetShape(const Type& type) const {
const auto* ttype = type.as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
return shape;
}
/*!
* \brief Check if a call has the provided name.
*
* \param call A Relay call node.
* \param op_name The name of the expected call.
*
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
bool IsOp(const CallNode* call, const std::string& op_name) const {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
* \brief A common interface that is used by various external runtime to
* generate the wrapper to invoke external kernels.
*
......
......@@ -30,110 +30,24 @@
#include <tvm/runtime/registry.h>
#include <fstream>
#include <numeric>
#include <sstream>
#include "../../utils.h"
#include "../codegen_c/codegen_c.h"
namespace tvm {
namespace relay {
namespace contrib {
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprVisitor, public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(GetRef<Var>(node));
out_.clear();
Output output;
output.name = node->name_hint();
out_.push_back(output);
}
void VisitExpr_(const TupleGetItemNode* op) final {
// Do nothing
}
void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;
// Get the arguments for various DNNL kernels.
if (IsOp(call, "nn.conv2d")) {
decl_stream << "dnnl_conv2d";
args = Conv2d(call);
} else if (IsOp(call, "nn.dense")) {
decl_stream << "dnnl_dense";
args = Dense(call);
} else if (IsOp(call, "nn.relu")) {
decl_stream << "dnnl_relu";
args = Relu(call);
} else if (IsOp(call, "nn.batch_norm")) {
decl_stream << "dnnl_bn";
args = BatchNorm(call);
} else if (IsOp(call, "add")) {
decl_stream << "dnnl_add";
args = Add(call);
} else {
LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
}
// Make function call with input buffers when visiting arguments
bool first = true;
decl_stream << "(";
for (size_t i = 0; i < call->args.size(); ++i) {
VisitExpr(call->args[i]);
for (auto out : out_) {
if (!first) {
decl_stream << ", ";
}
first = false;
decl_stream << out.name;
}
}
// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
using namespace backend;
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
}
std::string JIT(void) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
}
inline size_t GetShape1DSize(const Type& type) {
const auto shape = GetShape(type);
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}
private:
std::vector<std::string> Conv2d(const CallNode* call) {
std::vector<std::string> Conv2d(const CallNode* call) {
std::vector<std::string> args;
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
CHECK(conv2d_attr);
......@@ -157,9 +71,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
return args;
}
}
std::vector<std::string> Dense(const CallNode* call) {
std::vector<std::string> Dense(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type());
......@@ -170,9 +84,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
args.push_back(std::to_string(wshape[0]));
return args;
}
}
std::vector<std::string> Relu(const CallNode* call) {
std::vector<std::string> Relu(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
......@@ -182,9 +96,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
return args;
}
}
std::vector<std::string> BatchNorm(const CallNode* call) {
std::vector<std::string> BatchNorm(const CallNode* call) {
std::vector<std::string> args;
const auto* bn_attr = call->attrs.as<BatchNormAttrs>();
auto ishape = GetShape(call->args[0]->checked_type());
......@@ -198,9 +112,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
args.push_back(std::to_string(bn_attr->epsilon));
return args;
}
}
std::vector<std::string> Add(const CallNode* call) {
std::vector<std::string> Add(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
......@@ -210,6 +124,203 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
return args;
}
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprVisitor, public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(GetRef<Var>(node));
out_.clear();
Output output;
output.name = node->name_hint();
out_.push_back(output);
}
void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be required for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
}
void VisitExpr_(const ConstantNode* cn) final {
Constant constant = GetRef<Constant>(cn);
if (visited_.count(constant)) {
out_.push_back(visited_[constant]);
return;
}
out_.clear();
Output output;
output.name = "const_" + std::to_string(const_idx_++);
output.dtype = "float";
out_.push_back(output);
visited_[constant] = output;
runtime::NDArray array = cn->data;
// Get the number of elements.
int64_t num_elems = 1;
for (auto i : array.Shape()) num_elems *= i;
const auto* type_node = cn->checked_type().as<TensorTypeNode>();
CHECK(type_node);
CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
std::ostringstream buf_stream;
buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n";
const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
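// Inline the constant values into the generated source, one assignment per element.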
for (int64_t i = 0; i < num_elems; i++) {
buf_stream << " " << output.name << "[" << i << "] = " << ptr[i] << ";\n";
}
ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
}
void VisitExpr_(const CallNode* call) final {
GenerateBodyOutput ret;
if (const auto* func = call->op.as<FunctionNode>()) {
ret = GenerateCompositeFunctionCall(func, call);
} else {
ret = GenerateOpCall(call);
}
out_.clear();
for (size_t i = 0; i < ret.outputs.size(); ++i) {
buf_decl_.push_back(ret.buffers[i]);
out_.push_back(ret.outputs[i]);
}
ext_func_body.push_back(ret.decl);
}
std::string JIT(void) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
}
private:
struct GenerateBodyOutput {
std::string decl;
std::vector<std::string> buffers;
std::vector<Output> outputs;
};
std::vector<std::string> GetArgumentNames(const CallNode* call) {
std::vector<std::string> arg_names;
for (size_t i = 0; i < call->args.size(); ++i) {
VisitExpr(call->args[i]);
for (auto out : out_) {
arg_names.push_back(out.name);
}
}
return arg_names;
}
GenerateBodyOutput GenerateOpCall(const CallNode* call) {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey();
using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
static const std::map<std::string, std::pair<std::string, ArgFunType>> op_map = {
{"nn.conv2d", {"dnnl_conv2d", Conv2d}},
{"nn.dense", {"dnnl_dense", Dense}},
{"nn.relu", {"dnnl_relu", Relu}},
{"nn.batch_norm", {"dnnl_bn", BatchNorm}},
{"add", {"dnnl_add", Add}},
};
const auto op_name = GetRef<Op>(op_node)->name;
const auto iter = op_map.find(op_name);
if (iter != op_map.end()) {
return GenerateBody(call, iter->second.first, iter->second.second(call));
}
LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
return {};
}
GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
const CallNode* caller) {
const auto pattern_name = callee->GetAttr<runtime::String>(attr::kComposite);
CHECK(pattern_name.defined()) << "Only functions with the composite attribute are supported";
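// GetRootCall walks args[0] 'depth' times from the pattern's outermost call down to the
// root op: conv2d_bias_relu is relu(add(conv2d)) so it uses depth 2, conv2d_relu uses depth 1.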
if (pattern_name == "dnnl.conv2d_bias_relu") {
const auto* conv_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", GetArgumentNames(caller),
Conv2d(conv_call));
} else if (pattern_name == "dnnl.conv2d_relu") {
const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller),
Conv2d(conv_call));
}
LOG(FATAL) << "Unknown composite function:" << pattern_name;
return {};
}
GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
const std::vector<std::string>& attribute_args) {
return GenerateBody(root_call, func_name, GetArgumentNames(root_call), attribute_args);
}
GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
const std::vector<std::string>& func_args,
const std::vector<std::string>& attribute_args) {
// Make function call with input buffers when visiting arguments
CHECK_GT(func_args.size(), 0);
std::ostringstream decl_stream;
decl_stream << "(" << func_args[0];
for (size_t i = 1; i < func_args.size(); ++i) {
decl_stream << ", " << func_args[i];
}
// Analyze the output buffers
std::vector<Type> out_types;
if (root_call->checked_type()->IsInstance<TupleTypeNode>()) {
auto type_node = root_call->checked_type().as<TupleTypeNode>();
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (root_call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(root_call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(root_call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false);
}
GenerateBodyOutput ret;
for (const auto& out_type : out_types) {
this->PrintIndents();
const std::string out = "buf_" + std::to_string(buf_idx_++);
const auto out_size = GetShape1DSize(out_type);
decl_stream << ", " << out;
Output output;
output.name = out;
output.size = out_size;
output.dtype = GetDtypeString(out_type.as<TensorTypeNode>());
output.need_copy = true;
ret.buffers.push_back("float* " + out + " = (float*)std::malloc(4 * " +
std::to_string(out_size) + ");");
ret.outputs.push_back(output);
}
// Attach attribute arguments
for (size_t i = 0; i < attribute_args.size(); ++i) {
decl_stream << ", " << attribute_args[i];
}
decl_stream << ");";
ret.decl = func_name + decl_stream.str();
return ret;
}
/*! \brief The id of the external dnnl ext_func. */
......@@ -219,6 +330,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
* output to a buffer that may be consumed by other kernels.
*/
int buf_idx_{0};
/*! \brief The index of global constants. */
int const_idx_ = 0;
/*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
Array<Var> ext_func_args_;
/*! \brief statement of the function that will be compiled using DNNL kernels. */
......@@ -227,6 +340,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
std::vector<std::string> buf_decl_;
/*! \brief The names of the outputs. */
std::vector<Output> out_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
};
/*!
......
......@@ -25,18 +25,19 @@
#define TVM_RELAY_BACKEND_UTILS_H_
#include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
#include <typeinfo>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace tvm {
namespace relay {
......@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
*/
template <typename R, typename... Args>
inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
auto *pf = GetPackedFunc(func_name);
auto* pf = GetPackedFunc(func_name);
CHECK(pf != nullptr) << "can not find packed function";
return runtime::TypedPackedFunc<R(Args...)>(*pf);
}
......@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
* \param params params dict
* \return relay::Function
*/
inline relay::Function
BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
inline relay::Function BindParamsByName(
relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
......@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func,
return ret;
}
/*!
* \brief Extract the shape from a Relay tensor type.
* \param type The provided type.
* \return The extracted shape in a list.
*/
inline std::vector<int> GetShape(const Type& type) {
const auto* ttype = type.as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
return shape;
}
/*!
* \brief Check if a call has the provided name.
* \param call A Relay call node.
* \param op_name The name of the expected call.
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
inline bool IsOp(const CallNode* call, const std::string& op_name) {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
* \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
* \param call A Relay call node. Typically nn.relu when called the first time.
* \param depth The number of calls before the root op, counting from current_call.
* \param expected_op_names The names of ops in this fused call. Example: {"nn.conv2d", "add",
* "nn.relu"}
* \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
*/
inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
const std::vector<std::string>& expected_op_names) {
CHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
IsOp(current_call, expected_op_names[depth]));
if (depth == 0) {
return current_call;
}
CHECK_GT(current_call->args.size(), 0);
const auto* next_call = current_call->args[0].as<CallNode>();
return GetRootCall(next_call, depth - 1, expected_op_names);
}
} // namespace backend
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_UTILS_H_
......@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
......@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen.
pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
......@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle));
}
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
......@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_};
if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_};
memory::dims conv2d_bias_tz = {p_O_};
memory::dims conv2d_dst_tz = {p_N_, p_O_,
(p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_};
memory::dims conv2d_strides = {p_Sh_, p_Sw_};
memory::dims conv2d_padding = {p_Ph_, p_Pw_};
std::vector<float> conv2d_bias(p_O_, 0);
auto user_src_memory =
memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng,
weights);
auto conv2d_user_bias_memory =
memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data());
auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory =
memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights);
auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
......@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);
auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct,
conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md,
conv2d_strides, conv2d_padding, conv2d_padding);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng);
prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng);
auto conv2d_src_memory = user_src_memory;
auto conv2d_weights_memory = user_weights_memory;
......@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
read_from_dnnl_memory(out, conv2d_dst_memory);
}
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Kw_, int p_Sh_, int p_Sw_) {
primitive_attr attr;
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr);
}
primitive_attr create_attr_with_relu_post_op() {
post_ops ops;
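// append_eltwise(scale, alg, alpha, beta): scale 1.0 and a ReLU post-op (alpha is the negative slope, 0 here).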
ops.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f);
primitive_attr attr;
attr.set_post_ops(ops);
return attr;
}
extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) {
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op());
}
extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
int p_N_, int p_C_, int p_H_, int p_W_, int p_O_,
int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph_,
p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op());
}
extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
int p_I_, int p_O_) {
using tag = memory::format_tag;
......@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory);
}
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_N_, int p_C_,
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) {
using tag = memory::format_tag;
using dt = memory::data_type;
......
......@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_);
extern "C" TVM_DLL void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_);
extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias,
float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_,
int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_,
int p_Sw_);
extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
int p_O_);
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
int p_e_);
float* variance, float* out, float* new_mean, float* new_variance,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_);
......
......@@ -27,11 +27,10 @@ from tvm import relay
from tvm import runtime
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay import transform
from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container
from tvm.relay.build_module import bind_params_by_name
# Leverage the pass manager to write a simple white list based annotator
......@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod["main"] = bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
......@@ -663,7 +662,7 @@ def test_constant_propagation():
add = x + y
log = relay.log(add)
f = relay.Function([x, y], log)
f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)})
f = bind_params_by_name(f, {"x": tvm.nd.array(ones)})
mod = tvm.IRModule()
mod["main"] = f
mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
......@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs():
partitioned = transform.PartitionGraph()(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_dnnl_fuse():
def make_pattern(with_bias=True):
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight")
bias = relay.var("bias")
conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
channels=8, padding=(1, 1))
if with_bias:
conv_out = relay.add(conv, bias)
else:
conv_out = conv
return relay.nn.relu(conv_out)
conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
def get_blocks(prefix, data, in_channel, out_channel,
include_bn=True, include_sigmoid=False):
weight = relay.var(prefix + "weight")
bn_gamma = relay.var(prefix + "bn_gamma")
bn_beta = relay.var(prefix + "bn_beta")
bn_mmean = relay.var(prefix + "bn_mean")
bn_mvar = relay.var(prefix + "bn_var")
layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
channels=out_channel, padding=(1, 1))
if include_bn:
bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta,
bn_mmean, bn_mvar)
layer = bn_output[0]
if include_sigmoid:
# dummy layer to prevent pattern detection
layer = relay.sigmoid(layer)
layer = relay.nn.relu(layer)
return layer
def get_net(include_bn=True, include_sigmoid=False):
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
block1 = get_blocks("block1_", data, 3, 8, include_bn, include_sigmoid)
# The second block is always conv + relu, to make it more interesting
block2 = get_blocks("block2_", block1, 8, 8, False, include_sigmoid)
return relay.Function(relay.analysis.free_vars(block2), block2)
def get_partitioned_mod(mod, params, pattern_table):
# This is required for constant folding
mod["main"] = bind_params_by_name(mod["main"], params)
remove_bn_pass = transform.Sequential([
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
])
composite_partition = transform.Sequential([
remove_bn_pass,
transform.MergeComposite(pattern_table),
transform.AnnotateTarget("dnnl"),
transform.PartitionGraph()
])
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
return composite_partition(mod)
def test_detect_pattern(pattern_table, include_bn, include_sigmoid,
num_expected_partition):
net = get_net(include_bn, include_sigmoid)
mod, params = tvm.relay.testing.create_workload(net)
mod = get_partitioned_mod(mod, params, pattern_table)
assert(len(mod.functions) - 1 == num_expected_partition) # -1 for main
def test_partition():
# conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
test_detect_pattern([conv2d_bias_relu_pat], True, False, 3)
# conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
test_detect_pattern([conv2d_relu_pat], True, False, 4)
# conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2)
# conv + relu, conv + relu -> two fused conv_relu
test_detect_pattern([conv2d_relu_pat], False, False, 2)
# conv + relu, conv + relu -> no fusion, 4 partitions each with a single op
test_detect_pattern([conv2d_bias_relu_pat], False, False, 4)
# conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5)
def test_partition_mobilenet():
mod, params = relay.testing.mobilenet.get_workload()
mod = get_partitioned_mod(mod, params, dnnl_patterns)
# 27 fused conv + bn + relu and one dense
assert(len(mod.functions) - 1 == 28) # -1 for main
def test_exec(mod, params, ref_mod, ref_params, out_shape):
ishape = (1, 3, 224, 224)
i_data = np.random.randn(*ishape).astype(np.float32)
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
ref_res = ref_ex.evaluate()(i_data, **ref_params)
compile_engine.get().clear()
mod = get_partitioned_mod(mod, params, dnnl_patterns)
check_result(mod, {"data": i_data},
out_shape, ref_res.asnumpy(), tol=1e-5, params=params)
test_partition()
test_partition_mobilenet()
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return
net = get_net()
mod, params = tvm.relay.testing.create_workload(net)
ref_mod, ref_params = tvm.relay.testing.create_workload(net)
test_exec(mod, params, ref_mod, ref_params, (1, 8, 224, 224))
# exec test on mobilenet is not possible due to manually inlined constants
# mod, params = relay.testing.mobilenet.get_workload()
# ref_mod, ref_params = relay.testing.mobilenet.get_workload()
# test_exec(mod, params, ref_mod, ref_params, (1, 1000))
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
......@@ -865,3 +986,4 @@ if __name__ == "__main__":
test_constant_propagation()
test_multiple_outputs()
test_mixed_single_multiple_outputs()
test_dnnl_fuse()