Unverified Commit 3616ebee by masahi Committed by GitHub

[BYOC] Add example of Composite + Annotate for DNNL fused op (#5272)

* merge change from dev branch

* fix string issue

* bring comaniac's change back
parent 4b27cd14
...@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True): ...@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d") _register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense") _register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu") _register_external_op_helper("nn.relu")
_register_external_op_helper("add") _register_external_op_helper("add")
_register_external_op_helper("subtract") _register_external_op_helper("subtract")
_register_external_op_helper("multiply") _register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
    """Check if the external DNNL codegen should be used.

    FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not supported yet.
    """
    return False
...@@ -19,19 +19,22 @@ ...@@ -19,19 +19,22 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "../../utils.h"
#include "codegen_c.h" #include "codegen_c.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace contrib { namespace contrib {
using namespace backend;
/*! /*!
* \brief An example codegen that is only used for quick prototyping and testing * \brief An example codegen that is only used for quick prototyping and testing
* purpose. Only several binary options are covered. Users * purpose. Only several binary options are covered. Users
......
...@@ -170,41 +170,6 @@ class CodegenCBase { ...@@ -170,41 +170,6 @@ class CodegenCBase {
virtual std::string JIT() = 0; virtual std::string JIT() = 0;
/*! /*!
* \brief Extract the shape from a Relay tensor type.
*
* \param type The provided type.
*
* \return The extracted shape in a list.
*/
std::vector<int> GetShape(const Type& type) const {
const auto* ttype = type.as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
return shape;
}
/*!
* \brief Check if a call has the provided name.
*
* \param call A Relay call node.
* \param op_name The name of the expected call.
*
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
bool IsOp(const CallNode* call, const std::string& op_name) const {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
* \brief A common interface that is used by various external runtime to * \brief A common interface that is used by various external runtime to
* generate the wrapper to invoke external kernels. * generate the wrapper to invoke external kernels.
* *
......
...@@ -30,110 +30,24 @@ ...@@ -30,110 +30,24 @@
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <fstream> #include <fstream>
#include <numeric>
#include <sstream> #include <sstream>
#include "../../utils.h"
#include "../codegen_c/codegen_c.h" #include "../codegen_c/codegen_c.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace contrib { namespace contrib {
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement using namespace backend;
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprVisitor, public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(GetRef<Var>(node));
out_.clear();
Output output;
output.name = node->name_hint();
out_.push_back(output);
}
void VisitExpr_(const TupleGetItemNode* op) final {
// Do nothing
}
void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;
// Get the arguments for various DNNL kernels.
if (IsOp(call, "nn.conv2d")) {
decl_stream << "dnnl_conv2d";
args = Conv2d(call);
} else if (IsOp(call, "nn.dense")) {
decl_stream << "dnnl_dense";
args = Dense(call);
} else if (IsOp(call, "nn.relu")) {
decl_stream << "dnnl_relu";
args = Relu(call);
} else if (IsOp(call, "nn.batch_norm")) {
decl_stream << "dnnl_bn";
args = BatchNorm(call);
} else if (IsOp(call, "add")) {
decl_stream << "dnnl_add";
args = Add(call);
} else {
LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
}
// Make function call with input buffers when visiting arguments
bool first = true;
decl_stream << "(";
for (size_t i = 0; i < call->args.size(); ++i) {
VisitExpr(call->args[i]);
for (auto out : out_) {
if (!first) {
decl_stream << ", ";
}
first = false;
decl_stream << out.name;
}
}
// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Attach attribute arguments inline size_t GetShape1DSize(const Type& type) {
for (size_t i = 0; i < args.size(); ++i) { const auto shape = GetShape(type);
decl_stream << ", " << args[i]; return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
} }
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
}
std::string JIT(void) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
}
private: std::vector<std::string> Conv2d(const CallNode* call) {
std::vector<std::string> Conv2d(const CallNode* call) {
std::vector<std::string> args; std::vector<std::string> args;
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>(); const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
CHECK(conv2d_attr); CHECK(conv2d_attr);
...@@ -157,9 +71,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -157,9 +71,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value)); args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
return args; return args;
} }
std::vector<std::string> Dense(const CallNode* call) { std::vector<std::string> Dense(const CallNode* call) {
std::vector<std::string> args; std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type()); auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type()); auto wshape = GetShape(call->args[1]->checked_type());
...@@ -170,9 +84,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -170,9 +84,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
args.push_back(std::to_string(wshape[0])); args.push_back(std::to_string(wshape[0]));
return args; return args;
} }
std::vector<std::string> Relu(const CallNode* call) { std::vector<std::string> Relu(const CallNode* call) {
std::vector<std::string> args; std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type()); auto ishape = GetShape(call->args[0]->checked_type());
...@@ -182,9 +96,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -182,9 +96,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
return args; return args;
} }
std::vector<std::string> BatchNorm(const CallNode* call) { std::vector<std::string> BatchNorm(const CallNode* call) {
std::vector<std::string> args; std::vector<std::string> args;
const auto* bn_attr = call->attrs.as<BatchNormAttrs>(); const auto* bn_attr = call->attrs.as<BatchNormAttrs>();
auto ishape = GetShape(call->args[0]->checked_type()); auto ishape = GetShape(call->args[0]->checked_type());
...@@ -198,9 +112,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -198,9 +112,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
args.push_back(std::to_string(bn_attr->epsilon)); args.push_back(std::to_string(bn_attr->epsilon));
return args; return args;
} }
std::vector<std::string> Add(const CallNode* call) { std::vector<std::string> Add(const CallNode* call) {
std::vector<std::string> args; std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type()); auto ishape = GetShape(call->args[0]->checked_type());
...@@ -210,6 +124,203 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -210,6 +124,203 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
return args; return args;
}
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprVisitor, public CodegenCBase {
public:
/*!
 * \brief Create a DNNL code generator for one external function.
 * \param id The identifier stored as the id of the external function.
 */
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(GetRef<Var>(node));
out_.clear();
Output output;
output.name = node->name_hint();
out_.push_back(output);
}
void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
}
/*!
 * \brief Visit a constant: emit code that materializes it as a heap-allocated float buffer
 * and make that buffer the single current output.
 *
 * The generated statements are inserted at the front of the function body. A constant that
 * was already emitted is served from the cache and no code is generated for it again.
 *
 * \param cn The constant being visited. Only the "float" dtype is accepted.
 */
void VisitExpr_(const ConstantNode* cn) final {
  Constant constant = GetRef<Constant>(cn);

  // out_ describes only the expression visited last. Clear it on BOTH paths; the cached
  // path used to append to whatever the previous visit had left behind.
  out_.clear();

  const auto cached = visited_.find(constant);
  if (cached != visited_.end()) {
    out_.push_back(cached->second);
    return;
  }

  Output output;
  output.name = "const_" + std::to_string(const_idx_++);
  output.dtype = "float";
  out_.push_back(output);
  visited_[constant] = output;

  runtime::NDArray array = cn->data;

  // Get the number of elements.
  int64_t num_elems = 1;
  for (auto i : array.Shape()) num_elems *= i;

  const auto* type_node = cn->checked_type().as<TensorTypeNode>();
  CHECK(type_node);
  CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";

  std::ostringstream buf_stream;
  buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n";

  // ToDLPack() returns a managed tensor that the caller must release through its deleter;
  // keep the handle so it can be released once the values have been printed.
  auto* managed = array.ToDLPack();
  const float* ptr = static_cast<float*>(managed->dl_tensor.data);
  for (int64_t i = 0; i < num_elems; i++) {
    buf_stream << " " << output.name << "[" << i << "] = " << ptr[i] << ";\n";
  }
  if (managed->deleter != nullptr) {
    managed->deleter(managed);
  }

  // Constants go first so that they are defined before any kernel call that uses them.
  ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
}
/*!
 * \brief Visit a call: generate the kernel invocation for it and publish its outputs.
 *
 * A call whose callee is a function is handled as a composite function; any other callee is
 * handled as a single operator.
 *
 * \param call The call being visited.
 */
void VisitExpr_(const CallNode* call) final {
  const auto* func = call->op.as<FunctionNode>();
  GenerateBodyOutput ret =
      func != nullptr ? GenerateCompositeFunctionCall(func, call) : GenerateOpCall(call);

  // One buffer declaration is recorded per produced output.
  const size_t num_outputs = ret.outputs.size();
  buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.begin() + num_outputs);
  out_.assign(ret.outputs.begin(), ret.outputs.end());

  ext_func_body.push_back(ret.decl);
}
std::string JIT(void) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
}
private:
/*! \brief The generated code pieces that belong to one kernel invocation. */
struct GenerateBodyOutput {
  /*! \brief The statement that calls the kernel. */
  std::string decl;
  /*! \brief The statements declaring the output buffers, one per output. */
  std::vector<std::string> buffers;
  /*! \brief The descriptions of the outputs written by the kernel. */
  std::vector<Output> outputs;
};
std::vector<std::string> GetArgumentNames(const CallNode* call) {
std::vector<std::string> arg_names;
for (size_t i = 0; i < call->args.size(); ++i) {
VisitExpr(call->args[i]);
for (auto out : out_) {
arg_names.push_back(out.name);
}
}
return arg_names;
}
/*!
 * \brief Generate the kernel invocation for a call to a single operator.
 * \param call The call; its callee must be an operator listed in the table below.
 * \return The generated invocation, buffer declarations and outputs.
 */
GenerateBodyOutput GenerateOpCall(const CallNode* call) {
  const auto* op_node = call->op.as<OpNode>();
  CHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey();

  // Operator name -> (kernel name, function extracting the attribute arguments).
  using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
  static const std::map<std::string, std::pair<std::string, ArgFunType>> op_map = {
      {"nn.conv2d", {"dnnl_conv2d", Conv2d}},
      {"nn.dense", {"dnnl_dense", Dense}},
      {"nn.relu", {"dnnl_relu", Relu}},
      {"nn.batch_norm", {"dnnl_bn", BatchNorm}},
      {"add", {"dnnl_add", Add}},
  };

  const auto entry = op_map.find(op_node->name);
  if (entry == op_map.end()) {
    LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
    return {};
  }

  const std::string& kernel_name = entry->second.first;
  const ArgFunType& attribute_getter = entry->second.second;
  return GenerateBody(call, kernel_name, attribute_getter(call));
}
/*!
 * \brief Generate the fused kernel invocation for a call to a composite function.
 * \param callee The called function; it must carry a composite attribute naming a known
 * pattern.
 * \param caller The call to the composite function; its arguments become kernel inputs.
 * \return The generated invocation, buffer declarations and outputs.
 */
GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
                                                 const CallNode* caller) {
  const auto pattern_name = callee->GetAttr<runtime::String>(attr::kComposite);
  CHECK(pattern_name.defined()) << "Only functions with composite attribute supported";

  // The body of the composite function is the outermost call of the fused chain.
  const auto* outermost = callee->body.as<CallNode>();

  if (pattern_name == "dnnl.conv2d_bias_relu") {
    const auto* conv_call = GetRootCall(outermost, 2, {"nn.conv2d", "add", "nn.relu"});
    const auto kernel_inputs = GetArgumentNames(caller);
    return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", kernel_inputs,
                        Conv2d(conv_call));
  }

  if (pattern_name == "dnnl.conv2d_relu") {
    const auto* conv_call = GetRootCall(outermost, 1, {"nn.conv2d", "nn.relu"});
    const auto kernel_inputs = GetArgumentNames(caller);
    return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", kernel_inputs, Conv2d(conv_call));
  }

  LOG(FATAL) << "Unknown composite function:" << pattern_name;
  return {};
}
/*!
 * \brief Generate a kernel invocation whose inputs are the arguments of the root call itself.
 * \param root_call The call whose arguments and result type drive the generation.
 * \param func_name The name of the kernel to invoke.
 * \param attribute_args The attribute arguments appended after the buffers.
 * \return The generated invocation, buffer declarations and outputs.
 */
GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
                                const std::vector<std::string>& attribute_args) {
  const std::vector<std::string> func_args = GetArgumentNames(root_call);
  return GenerateBody(root_call, func_name, func_args, attribute_args);
}
GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
const std::vector<std::string>& func_args,
const std::vector<std::string>& attribute_args) {
// Make function call with input buffers when visiting arguments
CHECK_GT(func_args.size(), 0);
std::ostringstream decl_stream;
decl_stream << "(" << func_args[0];
for (size_t i = 1; i < func_args.size(); ++i) {
decl_stream << ", " << func_args[i];
}
// Analyze the output buffers
std::vector<Type> out_types;
if (root_call->checked_type()->IsInstance<TupleTypeNode>()) {
auto type_node = root_call->checked_type().as<TupleTypeNode>();
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (root_call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(root_call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(root_call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false);
}
GenerateBodyOutput ret;
for (const auto& out_type : out_types) {
this->PrintIndents();
const std::string out = "buf_" + std::to_string(buf_idx_++);
const auto out_size = GetShape1DSize(out_type);
decl_stream << ", " << out;
Output output;
output.name = out;
output.size = out_size;
output.dtype = GetDtypeString(out_type.as<TensorTypeNode>());
output.need_copy = true;
ret.buffers.push_back("float* " + out + " = (float*)std::malloc(4 * " +
std::to_string(out_size) + ");");
ret.outputs.push_back(output);
}
// Attach attribute arguments
for (size_t i = 0; i < attribute_args.size(); ++i) {
decl_stream << ", " << attribute_args[i];
}
decl_stream << ");";
ret.decl = func_name + decl_stream.str();
return ret;
} }
/*! \brief The id of the external dnnl ext_func. */ /*! \brief The id of the external dnnl ext_func. */
...@@ -219,6 +330,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -219,6 +330,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
* output to a buffer that may be consumed by other kernels. * output to a buffer that may be consumed by other kernels.
*/ */
int buf_idx_{0}; int buf_idx_{0};
/*! \brief The index of global constants. */
int const_idx_ = 0;
/*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
Array<Var> ext_func_args_; Array<Var> ext_func_args_;
/*! \brief statement of the function that will be compiled using DNNL kernels. */ /*! \brief statement of the function that will be compiled using DNNL kernels. */
...@@ -227,6 +340,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -227,6 +340,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
std::vector<std::string> buf_decl_; std::vector<std::string> buf_decl_;
/*! \brief The name of the the outputs. */ /*! \brief The name of the the outputs. */
std::vector<Output> out_; std::vector<Output> out_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
}; };
/*! /*!
......
...@@ -25,18 +25,19 @@ ...@@ -25,18 +25,19 @@
#define TVM_RELAY_BACKEND_UTILS_H_ #define TVM_RELAY_BACKEND_UTILS_H_
#include <dmlc/json.h> #include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h> #include <tvm/relay/type.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/te/operation.h> #include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
#include <typeinfo>
#include <string> #include <string>
#include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) { ...@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
*/ */
template <typename R, typename... Args> template <typename R, typename... Args>
inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) { inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
auto *pf = GetPackedFunc(func_name); auto* pf = GetPackedFunc(func_name);
CHECK(pf != nullptr) << "can not find packed function"; CHECK(pf != nullptr) << "can not find packed function";
return runtime::TypedPackedFunc<R(Args...)>(*pf); return runtime::TypedPackedFunc<R(Args...)>(*pf);
} }
...@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) { ...@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
* \param params params dict * \param params params dict
* \return relay::Function * \return relay::Function
*/ */
inline relay::Function inline relay::Function BindParamsByName(
BindParamsByName(relay::Function func, relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict; std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var; std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) { for (auto arg : func->params) {
...@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func, ...@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func,
return ret; return ret;
} }
/*!
 * \brief Collect the dimensions of a Relay tensor type.
 * \param type The type to inspect. It must be a tensor type whose dimensions are all integer
 * immediates.
 * \return The dimensions, in order.
 */
inline std::vector<int> GetShape(const Type& type) {
  const auto* ttype = type.as<TensorTypeNode>();
  CHECK(ttype) << "Expect TensorTypeNode";

  std::vector<int> shape;
  for (const auto& dim : ttype->shape) {
    auto* val = dim.as<IntImmNode>();
    CHECK(val);
    shape.push_back(val->value);
  }
  return shape;
}
/*!
* \brief Check if a call has the provided name.
* \param call A Relay call node.
* \param op_name The name of the expected call.
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
inline bool IsOp(const CallNode* call, const std::string& op_name) {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}
/*!
 * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
 * \param current_call The outermost call of the chain, e.g. nn.relu in the example above.
 * \param depth The number of calls before the root op, counting from current_call.
 * \param expected_op_names The names of ops in this fused call, root first. Example:
 * {"nn.conv2d", "add", "nn.relu"}
 * \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
 */
inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
                                   const std::vector<std::string>& expected_op_names) {
  // Walk down the first-argument chain one level at a time, validating each level on the way.
  while (true) {
    CHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
          IsOp(current_call, expected_op_names[depth]));
    if (depth == 0) {
      break;
    }

    CHECK_GT(current_call->args.size(), 0);
    // A non-call argument yields nullptr here, which the check at the top then rejects.
    current_call = current_call->args[0].as<CallNode>();
    --depth;
  }
  return current_call;
}
} // namespace backend } // namespace backend
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_BACKEND_UTILS_H_ #endif // TVM_RELAY_BACKEND_UTILS_H_
...@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this // Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation // pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op` // pass. This is because memory allocation pass will insert `invoke_tvm_op`
...@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen. // external codegen.
pass_seqs.push_back(transform::Inline()); pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions. // Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
...@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { ...@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle)); std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle));
} }
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) {
int p_Sh_, int p_Sw_) {
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
engine eng(engine::kind::cpu, 0); engine eng(engine::kind::cpu, 0);
...@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_}; memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_};
if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_}; if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_};
memory::dims conv2d_bias_tz = {p_O_}; memory::dims conv2d_bias_tz = {p_O_};
memory::dims conv2d_dst_tz = {p_N_, p_O_, memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_}; (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_};
memory::dims conv2d_strides = {p_Sh_, p_Sw_}; memory::dims conv2d_strides = {p_Sh_, p_Sw_};
memory::dims conv2d_padding = {p_Ph_, p_Pw_}; memory::dims conv2d_padding = {p_Ph_, p_Pw_};
std::vector<float> conv2d_bias(p_O_, 0); auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory =
auto user_src_memory = memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights);
memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
auto user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng,
weights);
auto conv2d_user_bias_memory =
memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data());
auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
...@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw); auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);
auto conv2d_desc = convolution_forward::desc( auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct, prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md, conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding);
conv2d_strides, conv2d_padding, conv2d_padding); auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng);
auto conv2d_src_memory = user_src_memory; auto conv2d_src_memory = user_src_memory;
auto conv2d_weights_memory = user_weights_memory; auto conv2d_weights_memory = user_weights_memory;
...@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, ...@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
read_from_dnnl_memory(out, conv2d_dst_memory); read_from_dnnl_memory(out, conv2d_dst_memory);
} }
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Kw_, int p_Sh_, int p_Sw_) {
primitive_attr attr;
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr);
}
/*!
 * \brief Build primitive attributes carrying a single ReLU element-wise post-op.
 * \return The attributes with the post-op chain installed.
 */
primitive_attr create_attr_with_relu_post_op() {
  post_ops relu_chain;
  relu_chain.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f);

  primitive_attr relu_attr;
  relu_attr.set_post_ops(relu_chain);
  return relu_attr;
}
/*!
 * \brief 2D convolution with a ReLU post-op: the shared implementation invoked with an
 * all-zero bias.
 */
extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_, int p_C_,
                                       int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
                                       int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) {
  // The shared implementation always takes a bias, so pass p_O_ zeros.
  std::vector<float> zero_bias(p_O_, 0);
  dnnl_conv2d_common(data, weights, zero_bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
                     p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op());
}
extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
int p_N_, int p_C_, int p_H_, int p_W_, int p_O_,
int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph_,
p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op());
}
extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
int p_I_, int p_O_) { int p_I_, int p_O_) {
using tag = memory::format_tag; using tag = memory::format_tag;
...@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, ...@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory); read_from_dnnl_memory(out, dst_memory);
} }
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
float* variance, float* out, int p_N_, int p_C_, float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) { int p_H_, int p_W_, int p_E_) {
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
......
...@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int ...@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_);
/*!
 * \brief 2D convolution fused with ReLU. Takes the same arguments as dnnl_conv2d.
 */
extern "C" TVM_DLL void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_,
                                               int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
                                               int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
                                               int p_Sh_, int p_Sw_);
/*!
 * \brief 2D convolution fused with a bias and ReLU. Takes the same arguments as dnnl_conv2d
 * plus the bias buffer, which comes between the weights and the output.
 */
extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias,
                                                    float* out, int p_N_, int p_C_, int p_H_,
                                                    int p_W_, int p_O_, int p_G_, int p_Ph_,
                                                    int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_,
                                                    int p_Sw_);
extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
int p_O_); int p_O_);
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_, float* variance, float* out, float* new_mean, float* new_variance,
int p_e_); int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_); int p_h_, int p_w_);
......
...@@ -27,11 +27,10 @@ from tvm import relay ...@@ -27,11 +27,10 @@ from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay import transform
from tvm.relay.backend import compile_engine from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container from tvm.relay.build_module import bind_params_by_name
# Leverage the pass manager to write a simple white list based annotator # Leverage the pass manager to write a simple white list based annotator
...@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet(): ...@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload( mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32') batch_size=1, dtype='float32')
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) mod["main"] = bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget(["dnnl"])(mod) mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod) mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod) mod = transform.PartitionGraph()(mod)
...@@ -663,7 +662,7 @@ def test_constant_propagation(): ...@@ -663,7 +662,7 @@ def test_constant_propagation():
add = x + y add = x + y
log = relay.log(add) log = relay.log(add)
f = relay.Function([x, y], log) f = relay.Function([x, y], log)
f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)}) f = bind_params_by_name(f, {"x": tvm.nd.array(ones)})
mod = tvm.IRModule() mod = tvm.IRModule()
mod["main"] = f mod["main"] = f
mod = WhiteListAnnotator(["add"], "ccompiler")(mod) mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
...@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs(): ...@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs():
partitioned = transform.PartitionGraph()(mod) partitioned = transform.PartitionGraph()(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_dnnl_fuse():
    """Exercise MergeComposite + AnnotateTarget + PartitionGraph for DNNL.

    The partition-count checks always run. The numerical execution check only
    runs when the global function "relay.ext.dnnl" is registered; otherwise the
    test prints a message and returns early.
    """

    def make_pattern(with_bias=True):
        """Build the relay expression used as a composite pattern.

        The expression is conv2d -> relu, with an add of a "bias" variable
        inserted between the two when ``with_bias`` is true.
        """
        inp = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        kernel = relay.var("weight")
        bias_var = relay.var("bias")
        conv = relay.nn.conv2d(
            data=inp,
            weight=kernel,
            kernel_size=(3, 3),
            channels=8,
            padding=(1, 1),
        )
        pre_activation = relay.add(conv, bias_var) if with_bias else conv
        return relay.nn.relu(pre_activation)

    # (composite name, pattern expression) pairs handed to MergeComposite.
    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]

    def get_blocks(prefix, data, in_channel, out_channel,
                   include_bn=True, include_sigmoid=False):
        """Build conv2d [-> batch_norm] [-> sigmoid] -> relu on top of ``data``.

        ``in_channel`` is accepted but not read by this helper.
        """
        kernel = relay.var(prefix + "weight")
        gamma = relay.var(prefix + "bn_gamma")
        beta = relay.var(prefix + "bn_beta")
        moving_mean = relay.var(prefix + "bn_mean")
        moving_var = relay.var(prefix + "bn_var")

        out = relay.nn.conv2d(
            data=data,
            weight=kernel,
            kernel_size=(3, 3),
            channels=out_channel,
            padding=(1, 1),
        )
        if include_bn:
            # batch_norm yields several outputs; only the first one is used.
            out = relay.nn.batch_norm(out, gamma, beta, moving_mean, moving_var)[0]
        if include_sigmoid:
            # dummy layer to prevent pattern detection
            out = relay.sigmoid(out)
        return relay.nn.relu(out)

    def get_net(include_bn=True, include_sigmoid=False):
        """Stack two blocks and wrap them in a relay.Function over the free vars."""
        inp = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        first = get_blocks("block1_", inp, 3, 8, include_bn, include_sigmoid)
        # The second block is always conv + relu, to make it more interesting
        second = get_blocks("block2_", first, 8, 8, False, include_sigmoid)
        return relay.Function(relay.analysis.free_vars(second), second)

    def partition_with_patterns(mod, params, pattern_table):
        """Bind params, simplify the graph, then merge/annotate/partition for DNNL."""
        # This is required for constant folding
        mod["main"] = bind_params_by_name(mod["main"], params)

        # Simplification passes run ahead of pattern matching.
        simplify = transform.Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        pipeline = transform.Sequential([
            simplify,
            transform.MergeComposite(pattern_table),
            transform.AnnotateTarget("dnnl"),
            transform.PartitionGraph(),
        ])

        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
            partitioned = pipeline(mod)
        return partitioned

    def check_partition_count(pattern_table, include_bn, include_sigmoid,
                              num_expected_partition):
        """Partition a fresh two-block net and compare the number of partitions."""
        net = get_net(include_bn, include_sigmoid)
        mod, params = tvm.relay.testing.create_workload(net)
        mod = partition_with_patterns(mod, params, pattern_table)
        # -1 for main
        assert len(mod.functions) - 1 == num_expected_partition

    def run_partition_cases():
        """Run every (patterns, include_bn, include_sigmoid, expected) scenario."""
        scenarios = [
            # conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
            ([conv2d_bias_relu_pat], True, False, 3),
            # conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
            ([conv2d_relu_pat], True, False, 4),
            # conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
            ([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2),
            # conv + relu, conv + relu -> two fused conv_relu
            ([conv2d_relu_pat], False, False, 2),
            # conv + relu, conv + relu -> no fusion, 4 partition each with a single op
            ([conv2d_bias_relu_pat], False, False, 4),
            # conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
            ([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5),
        ]
        for table, with_bn, with_sigmoid, expected in scenarios:
            check_partition_count(table, with_bn, with_sigmoid, expected)

    def run_mobilenet_partition_case():
        """Partition the mobilenet workload with both DNNL patterns."""
        mod, params = relay.testing.mobilenet.get_workload()
        mod = partition_with_patterns(mod, params, dnnl_patterns)
        # 27 fused conv + bn + relu and one dense; -1 for main
        assert len(mod.functions) - 1 == 28

    def run_and_compare(mod, params, ref_mod, ref_params, out_shape):
        """Compare the partitioned module's output with the graph executor's."""
        input_shape = (1, 3, 224, 224)
        i_data = np.random.randn(*input_shape).astype(np.float32)

        # Reference result from the unpartitioned module.
        reference_executor = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
        expected = reference_executor.evaluate()(i_data, **ref_params)

        compile_engine.get().clear()

        mod = partition_with_patterns(mod, params, dnnl_patterns)
        check_result(mod, {"data": i_data},
                     out_shape, expected.asnumpy(), tol=1e-5, params=params)

    run_partition_cases()
    run_mobilenet_partition_case()

    dnnl_codegen = tvm.get_global_func("relay.ext.dnnl", True)
    if not dnnl_codegen:
        print("skip because DNNL codegen is not available")
        return

    net = get_net()
    mod, params = tvm.relay.testing.create_workload(net)
    ref_mod, ref_params = tvm.relay.testing.create_workload(net)
    run_and_compare(mod, params, ref_mod, ref_params, (1, 8, 224, 224))

    # exec test on mobilenet is not possible due to manually inlined constants
    # mod, params = relay.testing.mobilenet.get_workload()
    # ref_mod, ref_params = relay.testing.mobilenet.get_workload()
    # run_and_compare(mod, params, ref_mod, ref_params, (1, 1000))
if __name__ == "__main__": if __name__ == "__main__":
test_multi_node_compiler() test_multi_node_compiler()
test_extern_ccompiler_single_op() test_extern_ccompiler_single_op()
...@@ -865,3 +986,4 @@ if __name__ == "__main__": ...@@ -865,3 +986,4 @@ if __name__ == "__main__":
test_constant_propagation() test_constant_propagation()
test_multiple_outputs() test_multiple_outputs()
test_mixed_single_multiple_outputs() test_mixed_single_multiple_outputs()
test_dnnl_fuse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment