Unverified commit f0f03647 authored by Cody Yu, committed by GitHub

[BYOC] Refine DNNL Codegen (#5288)

* Improve the DNNL codegen: handle TupleGetItem and multi-output buffers, and disable nn.batch_norm offloading until multiple outputs are fully supported

* Bind model params by name in the DNNL tests before annotation and partitioning

* trigger ci
parent 6ecfaaff
......@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not yet supported.
"""
return False
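
For context, these per-operator "target.dnnl" rules are what AnnotateTarget consults when deciding which operators to offload. Below is a minimal sketch of the downstream flow (illustrative only; the shapes are arbitrary, and actually compiling the partitioned functions requires a DNNL-enabled build):

import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 32, 14, 14))
w = relay.var("w", shape=(32, 32, 3, 3))
y = relay.nn.relu(relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

# Ops whose "target.dnnl" rule returns True get compiler_begin/compiler_end
# annotations; PartitionGraph then groups them into functions marked for DNNL.
mod = transform.AnnotateTarget("dnnl")(mod)
mod = transform.PartitionGraph()(mod)
print(mod)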
......@@ -53,12 +53,19 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
void VisitExpr_(const TupleGetItemNode* op) final {
// Do nothing
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be required for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
}
void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;
......@@ -96,30 +103,38 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
}
// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
// Analyze the output buffers
std::vector<Type> out_types;
if (call->checked_type()->IsInstance<TupleTypeNode>()) {
auto type_node = call->checked_type().as<TupleTypeNode>();
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}
out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
auto out_shape = GetShape(out_type);
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
......@@ -128,6 +143,14 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
out_.push_back(output);
}
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
}
std::string JIT(void) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
}
......
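
For context, the tuple handling above (both in VisitExpr_(TupleGetItemNode) and in the per-output buffer loop) targets patterns like the following Relay snippet, where nn.batch_norm yields a 3-tuple and only one item is consumed (a minimal sketch; shapes are arbitrary):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 16, 32, 32))
gamma, beta, mean, var = [relay.var(n, shape=(16,)) for n in ("gamma", "beta", "mean", "var")]
# nn.batch_norm returns a 3-tuple: (normalized, new_mean, new_variance)
bn = relay.nn.batch_norm(data, gamma, beta, mean, var)
# bn[0] is a TupleGetItem node; the codegen keeps only the selected item in out_
func = relay.Function([data, gamma, beta, mean, var], bn[0])
print(tvm.IRModule.from_expr(func))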
......@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
......@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen.
pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
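
For reference, the reordered passes have Python bindings, so the sequencing can be reproduced at the user level (a sketch only, applied to a toy module rather than through the VM compiler; ManifestAlloc is VM-specific and omitted here):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8))
y = relay.add(x, relay.const(np.ones((1, 8), dtype="float32")))
mod = tvm.IRModule.from_expr(relay.Function([x], y))

seq = tvm.transform.Sequential([
    relay.transform.FoldConstant(),  # compute away constant-only expressions
    relay.transform.FuseOps(),       # fuse primitive operators
    relay.transform.Inline(),        # inline functions lifted to module scope
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)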
......@@ -169,9 +169,11 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory);
}
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_N_, int p_C_,
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) {
// FIXME(@comaniac): BN has 3 outputs: out, new_mean, and new_variance, but we do not update
// the last two because nothing consumes them for now. This should be updated in the future.
using tag = memory::format_tag;
using dt = memory::data_type;
......
......@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
int p_e_);
float* variance, float* out, float* new_mean, float* new_variance,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_);
......
......@@ -161,7 +161,7 @@ def test_extern_dnnl():
test_annotate()
test_run()
@pytest.mark.skip(reason="fix constant node before opening this case")
def test_extern_dnnl_mobilenet():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
......@@ -172,6 +172,7 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget("dnnl")(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
......@@ -267,5 +268,5 @@ def test_composite_function():
if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
#test_extern_dnnl_mobilenet()
test_composite_function()
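
The bind_params_by_name call added to these tests is what turns the model weights into constants before annotation. A standalone sketch of the effect (illustrative shapes and a hypothetical random weight):

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 8, 8))
weight = relay.var("weight", shape=(16, 3, 3, 3))
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

params = {"weight": tvm.nd.array(np.random.rand(16, 3, 3, 3).astype("float32"))}
# After binding, "weight" becomes a relay.Constant inside main, so a partitioned
# DNNL subgraph carries the weight rather than taking it as a runtime argument.
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
print(mod["main"].params)  # only "data" remains as a function parameter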
......@@ -18,6 +18,7 @@
import os
import sys
import numpy as np
import pytest
import tvm
import tvm.relay.testing
......@@ -438,7 +439,7 @@ def test_extern_dnnl():
check_result(mod, {"data": i_data, "weight1": w1_data},
(1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
@pytest.mark.skip(reason="fix constant node before opening this case")
def test_extern_dnnl_mobilenet():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
......@@ -450,6 +451,7 @@ def test_extern_dnnl_mobilenet():
batch_size=1, dtype='float32')
op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = WhiteListAnnotator(op_list, "dnnl")(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
......@@ -862,7 +864,7 @@ if __name__ == "__main__":
test_extern_ccompiler_default_ops()
test_extern_ccompiler()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
#test_extern_dnnl_mobilenet()
test_function_lifting()
test_function_lifting_inline()
test_constant_propagation()
......