Unverified Commit 3616ebee by masahi Committed by GitHub

[BYOC] Add example of Composite + Annotate for DNNL fused op (#5272)

* merge change from dev branch

* fix string issue

* bring comanic's change back
parent 4b27cd14
...@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True): ...@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d") _register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense") _register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu") _register_external_op_helper("nn.relu")
_register_external_op_helper("add") _register_external_op_helper("add")
_register_external_op_helper("subtract") _register_external_op_helper("subtract")
_register_external_op_helper("multiply") _register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
"""
return False
...@@ -19,19 +19,22 @@ ...@@ -19,19 +19,22 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "../../utils.h"
#include "codegen_c.h" #include "codegen_c.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace contrib { namespace contrib {
using namespace backend;
/*! /*!
* \brief An example codegen that is only used for quick prototyping and testing * \brief An example codegen that is only used for quick prototyping and testing
* purpose. Only several binary options are covered. Users * purpose. Only several binary options are covered. Users
......
...@@ -170,41 +170,6 @@ class CodegenCBase { ...@@ -170,41 +170,6 @@ class CodegenCBase {
virtual std::string JIT() = 0; virtual std::string JIT() = 0;
/*! /*!
* \brief Extract the shape from a Relay tensor type.
*
* \param type The provided type.
*
* \return The extracted shape in a list.
*/
std::vector<int> GetShape(const Type& type) const {
const auto* ttype = type.as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
return shape;
}
/*!
* \brief Check if a call has the provided name.
*
* \param call A Relay call node.
* \param op_name The name of the expected call.
*
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
bool IsOp(const CallNode* call, const std::string& op_name) const {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
* \brief A common interface that is used by various external runtime to * \brief A common interface that is used by various external runtime to
* generate the wrapper to invoke external kernels. * generate the wrapper to invoke external kernels.
* *
......
...@@ -30,14 +30,102 @@ ...@@ -30,14 +30,102 @@
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <fstream> #include <fstream>
#include <numeric>
#include <sstream> #include <sstream>
#include "../../utils.h"
#include "../codegen_c/codegen_c.h" #include "../codegen_c/codegen_c.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace contrib { namespace contrib {
using namespace backend;
inline size_t GetShape1DSize(const Type& type) {
const auto shape = GetShape(type);
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}
std::vector<std::string> Conv2d(const CallNode* call) {
std::vector<std::string> args;
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
CHECK(conv2d_attr);
auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type());
// Args: N, C, H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}
// Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
args.push_back(std::to_string(wshape[0]));
args.push_back(std::to_string(conv2d_attr->groups));
args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
args.push_back(std::to_string(wshape[2]));
args.push_back(std::to_string(wshape[3]));
args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
return args;
}
std::vector<std::string> Dense(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type());
// Args: N, C, O
args.push_back(std::to_string(ishape[0]));
args.push_back(std::to_string(ishape[1]));
args.push_back(std::to_string(wshape[0]));
return args;
}
std::vector<std::string> Relu(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
// Args: N, C, H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}
return args;
}
std::vector<std::string> BatchNorm(const CallNode* call) {
std::vector<std::string> args;
const auto* bn_attr = call->attrs.as<BatchNormAttrs>();
auto ishape = GetShape(call->args[0]->checked_type());
// Args: N, C, H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}
// Args: epsilon
args.push_back(std::to_string(bn_attr->epsilon));
return args;
}
std::vector<std::string> Add(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
// Args: H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}
return args;
}
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement. // all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprVisitor, public CodegenCBase { class CodegenDNNL : public ExprVisitor, public CodegenCBase {
...@@ -53,79 +141,64 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -53,79 +141,64 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
void VisitExpr_(const TupleGetItemNode* op) final { void VisitExpr_(const TupleGetItemNode* op) final {
// Do nothing VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
} }
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const ConstantNode* cn) final {
std::ostringstream decl_stream; Constant constant = GetRef<Constant>(cn);
std::ostringstream buf_stream; if (visited_.count(constant)) {
// Args: ID out_.push_back(visited_[constant]);
std::vector<std::string> args; return;
// Get the arguments for various DNNL kernels.
if (IsOp(call, "nn.conv2d")) {
decl_stream << "dnnl_conv2d";
args = Conv2d(call);
} else if (IsOp(call, "nn.dense")) {
decl_stream << "dnnl_dense";
args = Dense(call);
} else if (IsOp(call, "nn.relu")) {
decl_stream << "dnnl_relu";
args = Relu(call);
} else if (IsOp(call, "nn.batch_norm")) {
decl_stream << "dnnl_bn";
args = BatchNorm(call);
} else if (IsOp(call, "add")) {
decl_stream << "dnnl_add";
args = Add(call);
} else {
LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
} }
// Make function call with input buffers when visiting arguments out_.clear();
bool first = true; Output output;
decl_stream << "("; output.name = "const_" + std::to_string(const_idx_++);
for (size_t i = 0; i < call->args.size(); ++i) { output.dtype = "float";
VisitExpr(call->args[i]); out_.push_back(output);
for (auto out : out_) { visited_[constant] = output;
if (!first) {
decl_stream << ", "; runtime::NDArray array = cn->data;
}
first = false;
decl_stream << out.name;
}
}
// Analyze the output buffer // Get the number of elements.
auto type_node = call->checked_type().as<TensorTypeNode>(); int64_t num_elems = 1;
for (auto i : array.Shape()) num_elems *= i;
const auto* type_node = cn->checked_type().as<TensorTypeNode>();
CHECK(type_node); CHECK(type_node);
const auto& dtype = GetDtypeString(type_node); CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type()); std::ostringstream buf_stream;
int out_size = 1; buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n";
for (size_t i = 0; i < out_shape.size(); ++i) { const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
out_size *= out_shape[i]; for (int64_t i = 0; i < num_elems; i++) {
buf_stream << " " << output.name << "[" << i << "] = " << ptr[i] << ";\n";
} }
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Attach attribute arguments ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
for (size_t i = 0; i < args.size(); ++i) { }
decl_stream << ", " << args[i];
void VisitExpr_(const CallNode* call) final {
GenerateBodyOutput ret;
if (const auto* func = call->op.as<FunctionNode>()) {
ret = GenerateCompositeFunctionCall(func, call);
} else {
ret = GenerateOpCall(call);
} }
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
// Update output buffer
out_.clear(); out_.clear();
Output output; for (size_t i = 0; i < ret.outputs.size(); ++i) {
output.name = out; buf_decl_.push_back(ret.buffers[i]);
output.dtype = dtype; out_.push_back(ret.outputs[i]);
output.need_copy = true; }
output.size = out_size; ext_func_body.push_back(ret.decl);
out_.push_back(output);
} }
std::string JIT(void) { std::string JIT(void) {
...@@ -133,83 +206,121 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -133,83 +206,121 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
private: private:
std::vector<std::string> Conv2d(const CallNode* call) { struct GenerateBodyOutput {
std::vector<std::string> args; std::string decl;
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>(); std::vector<std::string> buffers;
CHECK(conv2d_attr); std::vector<Output> outputs;
};
auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type()); std::vector<std::string> GetArgumentNames(const CallNode* call) {
std::vector<std::string> arg_names;
// Args: N, C, H, W for (size_t i = 0; i < call->args.size(); ++i) {
for (auto s : ishape) { VisitExpr(call->args[i]);
args.push_back(std::to_string(s)); for (auto out : out_) {
arg_names.push_back(out.name);
}
} }
return arg_names;
// Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
args.push_back(std::to_string(wshape[0]));
args.push_back(std::to_string(conv2d_attr->groups));
args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
args.push_back(std::to_string(wshape[2]));
args.push_back(std::to_string(wshape[3]));
args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
return args;
} }
std::vector<std::string> Dense(const CallNode* call) { GenerateBodyOutput GenerateOpCall(const CallNode* call) {
std::vector<std::string> args; const auto* op_node = call->op.as<OpNode>();
auto ishape = GetShape(call->args[0]->checked_type()); CHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey();
auto wshape = GetShape(call->args[1]->checked_type());
using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
// Args: N, C, O static const std::map<std::string, std::pair<std::string, ArgFunType>> op_map = {
args.push_back(std::to_string(ishape[0])); {"nn.conv2d", {"dnnl_conv2d", Conv2d}},
args.push_back(std::to_string(ishape[1])); {"nn.dense", {"dnnl_dense", Dense}},
args.push_back(std::to_string(wshape[0])); {"nn.relu", {"dnnl_relu", Relu}},
{"nn.batch_norm", {"dnnl_bn", BatchNorm}},
{"add", {"dnnl_add", Add}},
};
const auto op_name = GetRef<Op>(op_node)->name;
const auto iter = op_map.find(op_name);
if (iter != op_map.end()) {
return GenerateBody(call, iter->second.first, iter->second.second(call));
}
return args; LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
return {};
} }
std::vector<std::string> Relu(const CallNode* call) { GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
std::vector<std::string> args; const CallNode* caller) {
auto ishape = GetShape(call->args[0]->checked_type()); const auto pattern_name = callee->GetAttr<runtime::String>(attr::kComposite);
CHECK(pattern_name.defined()) << "Only functions with composite attribute supported";
// Args: N, C, H, W
for (auto s : ishape) { if (pattern_name == "dnnl.conv2d_bias_relu") {
args.push_back(std::to_string(s)); const auto* conv_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", GetArgumentNames(caller),
Conv2d(conv_call));
} else if (pattern_name == "dnnl.conv2d_relu") {
const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller),
Conv2d(conv_call));
} }
return args; LOG(FATAL) << "Unknown composite function:" << pattern_name;
return {};
} }
std::vector<std::string> BatchNorm(const CallNode* call) { GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
std::vector<std::string> args; const std::vector<std::string>& attribute_args) {
const auto* bn_attr = call->attrs.as<BatchNormAttrs>(); return GenerateBody(root_call, func_name, GetArgumentNames(root_call), attribute_args);
auto ishape = GetShape(call->args[0]->checked_type()); }
// Args: N, C, H, W GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
for (auto s : ishape) { const std::vector<std::string>& func_args,
args.push_back(std::to_string(s)); const std::vector<std::string>& attribute_args) {
// Make function call with input buffers when visiting arguments
CHECK_GT(func_args.size(), 0);
std::ostringstream decl_stream;
decl_stream << "(" << func_args[0];
for (size_t i = 1; i < func_args.size(); ++i) {
decl_stream << ", " << func_args[i];
} }
// Args: epsilon // Analyze the output buffers
args.push_back(std::to_string(bn_attr->epsilon)); std::vector<Type> out_types;
if (root_call->checked_type()->IsInstance<TupleTypeNode>()) {
return args; auto type_node = root_call->checked_type().as<TupleTypeNode>();
} for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
std::vector<std::string> Add(const CallNode* call) { out_types.push_back(field);
std::vector<std::string> args; }
auto ishape = GetShape(call->args[0]->checked_type()); } else if (root_call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(root_call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(root_call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false);
}
// Args: H, W GenerateBodyOutput ret;
for (auto s : ishape) { for (const auto& out_type : out_types) {
args.push_back(std::to_string(s)); this->PrintIndents();
const std::string out = "buf_" + std::to_string(buf_idx_++);
const auto out_size = GetShape1DSize(out_type);
decl_stream << ", " << out;
Output output;
output.name = out;
output.size = out_size;
output.dtype = GetDtypeString(out_type.as<TensorTypeNode>());
output.need_copy = true;
ret.buffers.push_back("float* " + out + " = (float*)std::malloc(4 * " +
std::to_string(out_size) + ");");
ret.outputs.push_back(output);
} }
return args; // Attach attribute arguments
for (size_t i = 0; i < attribute_args.size(); ++i) {
decl_stream << ", " << attribute_args[i];
}
decl_stream << ");";
ret.decl = func_name + decl_stream.str();
return ret;
} }
/*! \brief The id of the external dnnl ext_func. */ /*! \brief The id of the external dnnl ext_func. */
...@@ -219,6 +330,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -219,6 +330,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
* output to a buffer that may be consumed by other kernels. * output to a buffer that may be consumed by other kernels.
*/ */
int buf_idx_{0}; int buf_idx_{0};
/*! \brief The index of global constants. */
int const_idx_ = 0;
/*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
Array<Var> ext_func_args_; Array<Var> ext_func_args_;
/*! \brief statement of the function that will be compiled using DNNL kernels. */ /*! \brief statement of the function that will be compiled using DNNL kernels. */
...@@ -227,6 +340,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -227,6 +340,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
std::vector<std::string> buf_decl_; std::vector<std::string> buf_decl_;
/*! \brief The name of the the outputs. */ /*! \brief The name of the the outputs. */
std::vector<Output> out_; std::vector<Output> out_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
}; };
/*! /*!
......
...@@ -25,18 +25,19 @@ ...@@ -25,18 +25,19 @@
#define TVM_RELAY_BACKEND_UTILS_H_ #define TVM_RELAY_BACKEND_UTILS_H_
#include <dmlc/json.h> #include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h> #include <tvm/relay/type.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/te/operation.h> #include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
#include <typeinfo>
#include <string> #include <string>
#include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) { ...@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
*/ */
template <typename R, typename... Args> template <typename R, typename... Args>
inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) { inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
auto *pf = GetPackedFunc(func_name); auto* pf = GetPackedFunc(func_name);
CHECK(pf != nullptr) << "can not find packed function"; CHECK(pf != nullptr) << "can not find packed function";
return runtime::TypedPackedFunc<R(Args...)>(*pf); return runtime::TypedPackedFunc<R(Args...)>(*pf);
} }
...@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) { ...@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
* \param params params dict * \param params params dict
* \return relay::Function * \return relay::Function
*/ */
inline relay::Function inline relay::Function BindParamsByName(
BindParamsByName(relay::Function func, relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict; std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var; std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) { for (auto arg : func->params) {
...@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func, ...@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func,
return ret; return ret;
} }
/*!
* \brief Extract the shape from a Relay tensor type.
* \param type The provided type.
* \return The extracted shape in a list.
*/
inline std::vector<int> GetShape(const Type& type) {
const auto* ttype = type.as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
return shape;
}
/*!
* \brief Check if a call has the provided name.
* \param call A Relay call node.
* \param op_name The name of the expected call.
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
inline bool IsOp(const CallNode* call, const std::string& op_name) {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
* \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
* \param call A Relay call node. Typically nn.relu when called the first time.
* \param depth The number of calls before the root op, counting from current_call.
* \param expected_op_names The names of ops in this fused call. Example: {"nn.conv2d", "add",
* "nn.relu"}
* \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
*/
inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
const std::vector<std::string>& expected_op_names) {
CHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
IsOp(current_call, expected_op_names[depth]));
if (depth == 0) {
return current_call;
}
CHECK_GT(current_call->args.size(), 0);
const auto* next_call = current_call->args[0].as<CallNode>();
return GetRootCall(next_call, depth - 1, expected_op_names);
}
} // namespace backend } // namespace backend
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_BACKEND_UTILS_H_ #endif // TVM_RELAY_BACKEND_UTILS_H_
...@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this // Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation // pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op` // pass. This is because memory allocation pass will insert `invoke_tvm_op`
...@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen. // external codegen.
pass_seqs.push_back(transform::Inline()); pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions. // Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
...@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { ...@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle)); std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle));
} }
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) {
int p_Sh_, int p_Sw_) {
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
engine eng(engine::kind::cpu, 0); engine eng(engine::kind::cpu, 0);
...@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_}; memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_};
if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_}; if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_};
memory::dims conv2d_bias_tz = {p_O_}; memory::dims conv2d_bias_tz = {p_O_};
memory::dims conv2d_dst_tz = {p_N_, p_O_, memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_}; (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_};
memory::dims conv2d_strides = {p_Sh_, p_Sw_}; memory::dims conv2d_strides = {p_Sh_, p_Sw_};
memory::dims conv2d_padding = {p_Ph_, p_Pw_}; memory::dims conv2d_padding = {p_Ph_, p_Pw_};
std::vector<float> conv2d_bias(p_O_, 0); auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory =
auto user_src_memory = memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights);
memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
auto user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng,
weights);
auto conv2d_user_bias_memory =
memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data());
auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
...@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw); auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);
auto conv2d_desc = convolution_forward::desc( auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct, prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md, conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding);
conv2d_strides, conv2d_padding, conv2d_padding); auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng);
auto conv2d_src_memory = user_src_memory; auto conv2d_src_memory = user_src_memory;
auto conv2d_weights_memory = user_weights_memory; auto conv2d_weights_memory = user_weights_memory;
...@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
read_from_dnnl_memory(out, conv2d_dst_memory); read_from_dnnl_memory(out, conv2d_dst_memory);
} }
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Kw_, int p_Sh_, int p_Sw_) {
primitive_attr attr;
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr);
}
primitive_attr create_attr_with_relu_post_op() {
post_ops ops;
ops.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f);
primitive_attr attr;
attr.set_post_ops(ops);
return attr;
}
extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) {
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op());
}
extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
int p_N_, int p_C_, int p_H_, int p_W_, int p_O_,
int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph_,
p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op());
}
extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
int p_I_, int p_O_) { int p_I_, int p_O_) {
using tag = memory::format_tag; using tag = memory::format_tag;
...@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, ...@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory); read_from_dnnl_memory(out, dst_memory);
} }
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
float* variance, float* out, int p_N_, int p_C_, float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) { int p_H_, int p_W_, int p_E_) {
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
......
...@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int ...@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_);
extern "C" TVM_DLL void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_);
extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias,
float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_,
int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_,
int p_Sw_);
extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
int p_O_); int p_O_);
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_, float* variance, float* out, float* new_mean, float* new_variance,
int p_e_); int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_); int p_h_, int p_w_);
......
...@@ -27,11 +27,10 @@ from tvm import relay ...@@ -27,11 +27,10 @@ from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay import transform
from tvm.relay.backend import compile_engine from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container from tvm.relay.build_module import bind_params_by_name
# Leverage the pass manager to write a simple white list based annotator # Leverage the pass manager to write a simple white list based annotator
...@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet(): ...@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload( mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32') batch_size=1, dtype='float32')
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) mod["main"] = bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget(["dnnl"])(mod) mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod) mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod) mod = transform.PartitionGraph()(mod)
...@@ -663,7 +662,7 @@ def test_constant_propagation(): ...@@ -663,7 +662,7 @@ def test_constant_propagation():
add = x + y add = x + y
log = relay.log(add) log = relay.log(add)
f = relay.Function([x, y], log) f = relay.Function([x, y], log)
f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)}) f = bind_params_by_name(f, {"x": tvm.nd.array(ones)})
mod = tvm.IRModule() mod = tvm.IRModule()
mod["main"] = f mod["main"] = f
mod = WhiteListAnnotator(["add"], "ccompiler")(mod) mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
...@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs(): ...@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs():
partitioned = transform.PartitionGraph()(mod) partitioned = transform.PartitionGraph()(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_dnnl_fuse():
    """Test Composite + Annotate + Partition for DNNL fused conv2d patterns.

    Builds conv2d(+bias)+relu composite patterns, checks that MergeComposite /
    AnnotateTarget / PartitionGraph produce the expected number of partitions
    for several network shapes, and — when the DNNL codegen is available —
    verifies numerical results against an unpartitioned reference run.
    """
    def make_pattern(with_bias=True):
        # Relay graph describing conv2d -> (optional bias add) -> relu,
        # used as a pattern for MergeComposite.
        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        weight = relay.var("weight")
        bias = relay.var("bias")
        conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
                               channels=8, padding=(1, 1))
        if with_bias:
            conv_out = relay.add(conv, bias)
        else:
            conv_out = conv
        return relay.nn.relu(conv_out)

    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]

    def get_blocks(prefix, data, in_channel, out_channel,
                   include_bn=True, include_sigmoid=False):
        # One network block: conv2d -> (optional bn) -> (optional sigmoid) -> relu.
        # `in_channel` is unused here (conv2d infers it) but kept for readability.
        weight = relay.var(prefix + "weight")
        bn_gamma = relay.var(prefix + "bn_gamma")
        bn_beta = relay.var(prefix + "bn_beta")
        bn_mmean = relay.var(prefix + "bn_mean")
        bn_mvar = relay.var(prefix + "bn_var")
        layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
                                channels=out_channel, padding=(1, 1))
        if include_bn:
            bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta,
                                            bn_mmean, bn_mvar)
            layer = bn_output[0]
        if include_sigmoid:
            # dummy layer to prevent pattern detection
            layer = relay.sigmoid(layer)
        layer = relay.nn.relu(layer)
        return layer

    def get_net(include_bn=True, include_sigmoid=False):
        # Two stacked blocks; the second one never has bn so both pattern
        # variants (with and without bias after bn folding) are exercised.
        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        block1 = get_blocks("block1_", data, 3, 8, include_bn, include_sigmoid)
        # The second block is always conv + relu, to make it more interesting
        block2 = get_blocks("block2_", block1, 8, 8, False, include_sigmoid)
        return relay.Function(relay.analysis.free_vars(block2), block2)

    def get_partitioned_mod(mod, params, pattern_table):
        # Bind params first: this is required for constant folding, which in
        # turn lets SimplifyInference + FoldScaleAxis fold bn into conv2d+add.
        mod["main"] = bind_params_by_name(mod["main"], params)
        remove_bn_pass = transform.Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        composite_partition = transform.Sequential([
            remove_bn_pass,
            transform.MergeComposite(pattern_table),
            transform.AnnotateTarget("dnnl"),
            transform.PartitionGraph()
        ])
        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
            return composite_partition(mod)

    def test_detect_pattern(pattern_table, include_bn, include_sigmoid,
                            num_expected_partition):
        # Partition a fresh workload and count the partitioned functions.
        net = get_net(include_bn, include_sigmoid)
        mod, params = tvm.relay.testing.create_workload(net)
        mod = get_partitioned_mod(mod, params, pattern_table)
        assert len(mod.functions) - 1 == num_expected_partition  # -1 for main

    def test_partition():
        # conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
        test_detect_pattern([conv2d_bias_relu_pat], True, False, 3)
        # conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
        test_detect_pattern([conv2d_relu_pat], True, False, 4)
        # conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2)
        # conv + relu, conv + relu -> two fused conv_relu
        test_detect_pattern([conv2d_relu_pat], False, False, 2)
        # conv + relu, conv + relu -> no fusion, 4 partition each with a single op
        test_detect_pattern([conv2d_bias_relu_pat], False, False, 4)
        # conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5)

    def test_partition_mobilenet():
        mod, params = relay.testing.mobilenet.get_workload()
        mod = get_partitioned_mod(mod, params, dnnl_patterns)
        # 27 fused conv + bn + relu and one dense
        assert len(mod.functions) - 1 == 28  # -1 for main

    def test_exec(mod, params, ref_mod, ref_params, out_shape):
        # Compare a partitioned run against an unpartitioned reference.
        ishape = (1, 3, 224, 224)
        i_data = np.random.randn(*ishape).astype(np.float32)
        ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
        ref_res = ref_ex.evaluate()(i_data, **ref_params)
        compile_engine.get().clear()

        mod = get_partitioned_mod(mod, params, dnnl_patterns)

        check_result(mod, {"data": i_data},
                     out_shape, ref_res.asnumpy(), tol=1e-5, params=params)

    test_partition()
    test_partition_mobilenet()

    if not tvm.get_global_func("relay.ext.dnnl", True):
        print("skip because DNNL codegen is not available")
        return

    net = get_net()
    mod, params = tvm.relay.testing.create_workload(net)
    ref_mod, ref_params = tvm.relay.testing.create_workload(net)
    test_exec(mod, params, ref_mod, ref_params, (1, 8, 224, 224))

    # exec test on mobilenet is not possible due to manually inlined constants
    # mod, params = relay.testing.mobilenet.get_workload()
    # ref_mod, ref_params = relay.testing.mobilenet.get_workload()
    # test_exec(mod, params, ref_mod, ref_params, (1, 1000))
if __name__ == "__main__": if __name__ == "__main__":
test_multi_node_compiler() test_multi_node_compiler()
test_extern_ccompiler_single_op() test_extern_ccompiler_single_op()
...@@ -865,3 +986,4 @@ if __name__ == "__main__": ...@@ -865,3 +986,4 @@ if __name__ == "__main__":
test_constant_propagation() test_constant_propagation()
test_multiple_outputs() test_multiple_outputs()
test_mixed_single_multiple_outputs() test_mixed_single_multiple_outputs()
test_dnnl_fuse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment