Unverified Commit 5958d60d by Zhi Committed by GitHub

[BYOC] Enhance partitioning and external codegen (#5310)

* Remove duplicated output args

* address comment

* fix codegen c

* improve comment

* VisitExprDefault_

* deduce type
parent fc75de9d
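
The recurring change in both the C and DNNL codegens below is dropping the `out_` side channel of `ExprVisitor` in favor of an `ExprFunctor` whose visit methods return the outputs of each sub-expression, plus a memoization table so a shared sub-expression is generated only once. A minimal sketch of that wrapper (not part of the patch; `Output` here is a stand-in for the descriptor declared in the contrib codegen headers, and member names are illustrative):

#include <string>
#include <unordered_map>
#include <vector>

#include <tvm/relay/expr_functor.h>

using namespace tvm;
using namespace tvm::relay;

// Stand-in for the Output descriptor used by the contrib codegens.
struct Output {
  std::string name;
};

// Sketch only: the memoized visit wrapper that CodegenC/CodegenDNNL adopt below.
class MemoizedCodegen : public ExprFunctor<std::vector<Output>(const Expr&)> {
 public:
  std::vector<Output> VisitExpr(const Expr& expr) final {
    auto it = visited_.find(expr);
    if (it != visited_.end()) return it->second;   // reuse outputs already generated
    std::vector<Output> outputs = ExprFunctor::VisitExpr(expr);
    visited_[expr] = outputs;                      // cache them for later references
    return outputs;
  }

  // Unhandled node kinds fail loudly instead of silently producing nothing.
  std::vector<Output> VisitExprDefault_(const Object* op) final {
    LOG(FATAL) << "codegen doesn't support: " << op->GetTypeKey();
    return {};
  }

 private:
  std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};
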
...@@ -40,35 +40,39 @@ using namespace backend;
  * purpose. Only several binary options are covered. Users
  * may need to extend them to cover more operators.
  */
-class CodegenC : public ExprVisitor, public CodegenCBase {
+class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
+                 public CodegenCBase {
  public:
   explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
 
-  void VisitExpr_(const VarNode* node) final {
+  std::vector<Output> VisitExpr(const Expr& expr) final {
+    if (visited_.count(expr)) return visited_.at(expr);
+    std::vector<Output> output = ExprFunctor::VisitExpr(expr);
+    visited_[expr] = output;
+    return output;
+  }
+
+  std::vector<Output> VisitExprDefault_(const Object* op) final {
+    LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
+    return {};
+  }
+
+  std::vector<Output> VisitExpr_(const VarNode* node) final {
     ext_func_args_.push_back(GetRef<Var>(node));
-    out_.clear();
     Output output;
     output.name = node->name_hint();
-    out_.push_back(output);
+    return {output};
   }
 
-  void VisitExpr_(const ConstantNode* cn) final {
-    Constant constant = GetRef<Constant>(cn);
-    if (visited_.count(constant)) {
-      // Note this is for demostration purpose. ConstantNode doesn't necessarily
-      // belong to calls. We need to revisit this when tuples come into play.
-      out_.push_back(visited_[constant]);
-      return;
-    }
+  std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
+    // Note this is for demonstration purpose. ConstantNode doesn't necessarily
+    // belong to calls. We need to revisit this when tuples come into play.
     std::ostringstream decl_stream;
     std::ostringstream buf_stream;
-    out_.clear();
     Output output;
     output.name = "const_" + std::to_string(const_idx_++);
-    out_.push_back(output);
-    visited_[constant] = output;
     runtime::NDArray array = cn->data;
     const auto& shape = array.Shape();
...@@ -99,9 +103,11 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     }
     buf_stream << "};";
     ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+
+    return {output};
   }
 
-  void VisitExpr_(const CallNode* call) final {
+  std::vector<Output> VisitExpr_(const CallNode* call) final {
     std::ostringstream macro_stream;
     std::ostringstream decl_stream;
     std::ostringstream buf_stream;
...@@ -138,8 +144,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     bool first = true;
     decl_stream << func_name << "(";
     for (size_t i = 0; i < call->args.size(); ++i) {
-      VisitExpr(call->args[i]);
-      for (auto out : out_) {
+      auto res = VisitExpr(call->args[i]);
+      for (auto out : res) {
         if (!first) {
           decl_stream << ", ";
         }
...@@ -162,13 +168,14 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     ext_func_body.push_back(decl_stream.str());
 
     // Update output buffer
-    out_.clear();
+    // Note C codegen only handles TensorType. Therefore, we don't flatten
+    // tuples and only return a single value.
     Output output;
     output.name = out;
     output.dtype = dtype;
     output.need_copy = true;
     output.size = out_size;
-    out_.push_back(output);
+    return {output};
   }
 
   /*!
...@@ -176,12 +183,12 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
    *
    * \return The emitted code.
    */
-  std::string JIT() {
+  std::string JIT(const std::vector<Output>& out) {
     // Write function macros
     for (auto decl : func_decl_) {
       code_stream_ << decl << "\n";
     }
-    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
+    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
   }
 
  private:
...@@ -202,9 +209,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
   /*! \brief The declaration statements of buffers. */
   std::vector<std::string> buf_decl_;
   /*! \brief The name and index pairs for output. */
-  std::vector<Output> out_;
-  /*! \brief The cached expressions. */
-  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
+  std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
 };
 
 class CSourceCodegen : public CSourceModuleCodegenBase {
...@@ -216,8 +221,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
     auto sid = GetExtSymbol(func);
     CodegenC builder(sid);
-    builder.VisitExpr(func->body);
-    code_stream_ << builder.JIT();
+    auto out = builder.VisitExpr(func->body);
+    code_stream_ << builder.JIT(out);
   }
 
   runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
......
...@@ -165,9 +165,11 @@ class CodegenCBase {
   /*!
    * \brief Emit the code for external runtime.
    *
+   * \param out The outputs.
+   *
    * \return The code string.
    */
-  virtual std::string JIT() = 0;
+  virtual std::string JIT(const std::vector<Output>& out) = 0;
 
   /*!
    * \brief A common interface that is used by various external runtime to
......
...@@ -128,42 +128,43 @@ std::vector<std::string> Add(const CallNode* call) {
 // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
 // all utilities and make a base class for users to implement.
-class CodegenDNNL : public ExprVisitor, public CodegenCBase {
+class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
+                    public CodegenCBase {
  public:
   explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
 
-  void VisitExpr_(const VarNode* node) final {
+  std::vector<Output> VisitExpr(const Expr& expr) final {
+    if (visited_.count(expr)) return visited_.at(expr);
+    std::vector<Output> output = ExprFunctor::VisitExpr(expr);
+    visited_[expr] = output;
+    return output;
+  }
+
+  std::vector<Output> VisitExprDefault_(const Object* op) final {
+    LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
+    return {};
+  }
+
+  std::vector<Output> VisitExpr_(const VarNode* node) final {
     ext_func_args_.push_back(GetRef<Var>(node));
-    out_.clear();
     Output output;
     output.name = node->name_hint();
-    out_.push_back(output);
+    return {output};
   }
 
-  void VisitExpr_(const TupleGetItemNode* op) final {
-    VisitExpr(op->tuple);
-    CHECK(out_.size() > static_cast<size_t>(op->index));
+  std::vector<Output> VisitExpr_(const TupleGetItemNode* op) final {
+    auto res = VisitExpr(op->tuple);
+    CHECK_GT(res.size(), static_cast<size_t>(op->index));
     // Only keep the item we want for the child node.
     // FIXME(@comaniac): The other items should still be required for the primary outputs.
-    auto item = out_[op->index];
-    out_.clear();
-    out_.push_back(item);
+    return {res[op->index]};
   }
 
-  void VisitExpr_(const ConstantNode* cn) final {
-    Constant constant = GetRef<Constant>(cn);
-    if (visited_.count(constant)) {
-      out_.push_back(visited_[constant]);
-      return;
-    }
-    out_.clear();
+  std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
     Output output;
     output.name = "const_" + std::to_string(const_idx_++);
     output.dtype = "float";
-    out_.push_back(output);
-    visited_[constant] = output;
     runtime::NDArray array = cn->data;
...@@ -176,16 +177,23 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
     CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
 
     std::ostringstream buf_stream;
-    buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n";
     const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
-    for (int64_t i = 0; i < num_elems; i++) {
-      buf_stream << " " << output.name << "[" << i << "] = " << ptr[i] << ";\n";
+
+    // Allocate large arrays on the static section to avoid stack overflow.
+    // Note that this would probably increase compilation time as the source
+    // file could be really large.
+    buf_stream << "static float " << output.name << "[" << num_elems << "] = {";
+    for (int64_t i = 0; i < num_elems - 1; i++) {
+      buf_stream << ptr[i] << ",";
     }
+    if (num_elems > 0) buf_stream << ptr[num_elems - 1];
+    buf_stream << "};\n";
     ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+
+    return {output};
   }
 
-  void VisitExpr_(const CallNode* call) final {
+  std::vector<Output> VisitExpr_(const CallNode* call) final {
     GenerateBodyOutput ret;
     if (const auto* func = call->op.as<FunctionNode>()) {
       ret = GenerateCompositeFunctionCall(func, call);
...@@ -193,16 +201,13 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
       ret = GenerateOpCall(call);
     }
 
-    out_.clear();
-    for (size_t i = 0; i < ret.outputs.size(); ++i) {
-      buf_decl_.push_back(ret.buffers[i]);
-      out_.push_back(ret.outputs[i]);
-    }
-
+    buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end());
     ext_func_body.push_back(ret.decl);
+    return ret.outputs;
   }
 
-  std::string JIT(void) {
-    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
+  std::string JIT(const std::vector<Output>& out) {
+    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
   }
 
  private:
...@@ -215,8 +220,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   std::vector<std::string> GetArgumentNames(const CallNode* call) {
     std::vector<std::string> arg_names;
     for (size_t i = 0; i < call->args.size(); ++i) {
-      VisitExpr(call->args[i]);
-      for (auto out : out_) {
+      auto res = VisitExpr(call->args[i]);
+      for (const auto& out : res) {
         arg_names.push_back(out.name);
       }
     }
...@@ -331,17 +336,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
    */
   int buf_idx_{0};
   /*! \brief The index of global constants. */
-  int const_idx_ = 0;
+  int const_idx_{0};
   /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
   Array<Var> ext_func_args_;
   /*! \brief statement of the function that will be compiled using DNNL kernels. */
   std::vector<std::string> ext_func_body;
   /*! \brief The declaration of intermediate buffers. */
   std::vector<std::string> buf_decl_;
-  /*! \brief The name of the the outputs. */
-  std::vector<Output> out_;
   /*! \brief The cached expressions. */
-  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
+  std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
 };
 
 /*!
...@@ -361,8 +364,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
     auto sid = GetExtSymbol(func);
     CodegenDNNL builder(sid);
-    builder.VisitExpr(func->body);
-    code_stream_ << builder.JIT();
+    auto out = builder.VisitExpr(func->body);
+    code_stream_ << builder.JIT(out);
   }
 
   /*!
......
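
For context, the DNNL constant handling above now embeds weights as static arrays in the generated source instead of malloc'ing a buffer and filling it element by element. For a hypothetical 4-element constant named const_0, the emitted snippet would look roughly like this (values are illustrative):

// Roughly what CodegenDNNL's buf_stream now produces for a 4-element float
// constant. The array lives in the static data section, so initializing it no
// longer risks blowing the stack, at the cost of a larger generated source file.
static float const_0[4] = {0.1,0.2,0.3,0.4};

// The old code instead emitted a heap allocation plus one assignment per element:
//   float* const_0 = (float*)std::malloc(4 * 4);
//   const_0[0] = 0.1; const_0[1] = 0.2; ...
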
...@@ -148,25 +148,42 @@ class Partitioner : public ExprMutator {
       CHECK_EQ(call->args.size(), 1U);
 
       // Traverse the rest graph.
-      auto input_expr = VisitExpr(call->args[0]);
+      Expr parent = call->args[0];
+      auto input_expr = VisitExpr(parent);
+
+      // Backtrace the parent to find the first ancestor node that is not a begin or end op
+      while (const auto* parent_call = parent.as<CallNode>()) {
+        if (parent_call->op == compiler_begin_op ||
+            parent_call->op == compiler_end_op) {
+          parent = parent_call->args[0];
+        } else {
+          break;
+        }
+      }
 
       AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
       int index = GetArgIdx(sg, GetRef<Call>(call));
       CHECK_NE(index, -1);
 
-      // The type of the created variable is the same as the compiler_begin
-      // node.
-      std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-      std::string varname =
-          target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
-      auto var = Var(varname, GetRef<Call>(call)->checked_type_);
-      auto cand = std::make_pair(var, input_expr);
-      if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
-          region_args[sg].end()) {
-        region_args[sg].push_back(cand);
-      }
-      return std::move(var);
+      if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
+        return shared_output_[parent][sg];
+      } else {
+        // The type of the created variable is the same as the compiler_begin
+        // node.
+        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+        std::string varname =
+            target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
+        auto var = Var(varname, GetRef<Call>(call)->checked_type_);
+        std::pair<Var, Expr> cand = std::make_pair(var, input_expr);
+        if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
+            region_args[sg].end()) {
+          region_args[sg].push_back(cand);
+        }
+        shared_output_[parent][sg] = var;
+        return std::move(var);
+      }
     } else {
       CHECK_EQ(call->op, compiler_end_op);
       // The annotation node is inserted on edge so it must have only one
...@@ -474,6 +491,12 @@ class Partitioner : public ExprMutator {
    * belongs to
    */
   std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
+
+  /*!\brief Cache the output that is shared by different nodes. */
+  using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
+  std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;
+
+  /*!\brief The IRModule used for partitioning. */
   IRModule module_;
 };
......
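
The partitioner hunks above do two things: they backtrack through compiler_begin/compiler_end annotations to find the real producer of an input, and they cache the created Var in shared_output_ so that when the same producer feeds a region through several annotated edges, the region gets a single input Var (and a single argument) rather than one per edge. A condensed sketch of that lookup, with simplified stand-in types rather than the actual TVM ones:

#include <map>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for the TVM types used by the partitioner.
using Expr = std::string;     // producer expression
using Region = int;           // annotated region id
struct Var { std::string name; };

// One Var per (producer, consuming region); reused when the same producer is
// seen again for the same region, mirroring the shared_output_ map added above.
std::map<std::pair<Expr, Region>, Var> shared_output;
std::map<Region, std::vector<Var>> region_args;

Var GetOrCreateRegionInput(const Expr& parent, Region sg, const std::string& varname) {
  auto key = std::make_pair(parent, sg);
  auto it = shared_output.find(key);
  if (it != shared_output.end()) return it->second;  // same output, same region: reuse
  Var var{varname};
  region_args[sg].push_back(var);                    // becomes one function parameter
  shared_output[key] = var;
  return var;
}

The test_multiple_use_of_an_output test added further down exercises this case: the shared add result feeds both log and subtract inside one partitioned region, yet is bound to a single parameter of the generated function.
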
...@@ -300,6 +300,14 @@ def test_extern_ccompiler_single_op():
     check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
 
 
+def set_func_attr(func, compile_name, symbol_name):
+    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+    func = func.with_attr("Compiler", compile_name)
+    func = func.with_attr("global_symbol", symbol_name)
+    return func
+
+
 def test_extern_ccompiler_default_ops():
     def expected():
         mod = tvm.IRModule()
...@@ -310,10 +318,7 @@ def test_extern_ccompiler_default_ops():
         add = x0 + y0
         # Function that uses C compiler
         func = relay.Function([x0, y0], add)
-        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", "ccompiler")
-        func = func.with_attr("global_symbol", "ccompiler_0")
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [x, y])
...@@ -380,32 +385,28 @@ def test_extern_dnnl():
     def expected():
         data0 = relay.var("data", shape=(ishape), dtype=dtype)
-        input0 = relay.var("input0", shape=(w1shape), dtype=dtype)
-        input1 = relay.var("input1", shape=(w1shape), dtype=dtype)
+        input0 = relay.var("input", shape=(w1shape), dtype=dtype)
         depthwise_conv2d_1 = relay.nn.conv2d(data0,
                                              input0,
                                              kernel_size=(3, 3),
                                              padding=(1, 1),
                                              groups=32)
         depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                             input1,
+                                             input0,
                                              kernel_size=(3, 3),
                                              padding=(1, 1),
                                              groups=32)
         out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
 
-        func = relay.Function([data0, input0, input1], out)
-        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", "dnnl")
-        func = func.with_attr("global_symbol", "dnnl_0")
+        func = relay.Function([data0, input0], out)
+        func = set_func_attr(func, "dnnl", "dnnl_0")
         glb_var = relay.GlobalVar("dnnl_0")
         mod = tvm.IRModule()
         mod[glb_var] = func
 
         data = relay.var("data", shape=(ishape), dtype=dtype)
         weight = relay.var("input", shape=(w1shape), dtype=dtype)
-        main_f = relay.Function([data, weight], glb_var(data, weight, weight))
+        main_f = relay.Function([data, weight], glb_var(data, weight))
         mod["main"] = main_f
         return mod
 
...@@ -444,7 +445,7 @@ def test_extern_dnnl():
     check_result(mod, {"data": i_data, "weight1": w1_data},
                  (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
 
-
+@pytest.mark.skip(reason="fix constant node before opening this case")
 def test_extern_dnnl_mobilenet():
     if not tvm.get_global_func("relay.ext.dnnl", True):
         print("skip because DNNL codegen is not available")
...@@ -521,10 +522,7 @@ def test_function_lifting():
         bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
 
         func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
                                bn.astuple())
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_compiler")
-        func0 = func0.with_attr("global_symbol", "test_compiler_0")
+        func0 = set_func_attr(func0, "test_compiler", "test_compiler_0")
         gv0 = relay.GlobalVar("test_compiler_0")
         mod[gv0] = func0
...@@ -538,10 +536,7 @@ def test_function_lifting():
                                channels=16,
                                padding=(1, 1))
         func1 = relay.Function([data1, weight1], conv)
-        func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Compiler", "test_compiler")
-        func1 = func1.with_attr("global_symbol", "test_compiler_1")
+        func1 = set_func_attr(func1, "test_compiler", "test_compiler_1")
         gv1 = relay.GlobalVar("test_compiler_1")
         mod[gv1] = func1
...@@ -610,10 +605,7 @@ def test_function_lifting_inline():
         bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
 
         func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
                                bn.astuple())
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_compiler")
-        func0 = func0.with_attr("global_symbol", "test_compiler_0")
+        func0 = set_func_attr(func0, "test_compiler", "test_compiler_0")
 
         # main function
         data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
...@@ -645,10 +637,7 @@ def test_constant_propagation():
         add = x0 + y0
         # Function that uses C compiler
         func = relay.Function([y0], add)
-        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", "ccompiler")
-        func = func.with_attr("global_symbol", "ccompiler_0")
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [y])
...@@ -745,10 +734,7 @@ def test_multiple_outputs():
         func0 = relay.Function([data, weight, bn_gamma, bn_beta,
                                 bn_mean, bn_var], tuple_o)
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_target")
-        func0 = func0.with_attr("global_symbol", "test_target_2")
+        func0 = set_func_attr(func0, "test_target", "test_target_2")
         gv0 = relay.GlobalVar("test_target_2")
         mod[gv0] = func0
...@@ -810,11 +796,7 @@ def test_mixed_single_multiple_outputs():
         f1_O_2 = relay.nn.relu(f1_O_1)
         f1_out = relay.Tuple((f1_O_2, f1_O_1))
         func1 = relay.Function([f1_cb1], f1_out)
-
-        func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Compiler", "test_target")
-        func1 = func1.with_attr("global_symbol", "test_target_1")
+        func1 = set_func_attr(func1, "test_target", "test_target_1")
         gv1 = relay.GlobalVar("test_target_1")
         mod[gv1] = func1
...@@ -823,11 +805,7 @@ def test_mixed_single_multiple_outputs():
         f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10))
         f2_O_3 = relay.add(f2_cb3, f2_cb4)
         func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
-
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_target")
-        func0 = func0.with_attr("global_symbol", "test_target_0")
+        func0 = set_func_attr(func0, "test_target", "test_target_0")
         gv0 = relay.GlobalVar("test_target_0")
         mod[gv0] = func0
...@@ -967,10 +945,96 @@ def test_dnnl_fuse():
     ref_mod, ref_params = tvm.relay.testing.create_workload(net)
     test_exec(mod, params, ref_mod, ref_params, (1, 8, 224, 224))
 
-    # exec test on mobilenet is not possible due to manually inlined constants
-    # mod, params = relay.testing.mobilenet.get_workload()
-    # ref_mod, ref_params = relay.testing.mobilenet.get_workload()
-    # test_exec(mod, params, ref_mod, ref_params, (1, 1000))
+    mod, params = relay.testing.mobilenet.get_workload()
+    ref_mod, ref_params = relay.testing.mobilenet.get_workload()
+    test_exec(mod, params, ref_mod, ref_params, (1, 1000))
+
+
+def test_multiple_use_of_an_output():
+    def expected_same_output_region():
+        mod = tvm.IRModule()
+        x = relay.var("x", shape=(8, 8))
+        y = relay.var("y", shape=(8, 8))
+        z = relay.var("z", shape=(8, 8))
+        x0 = relay.var("x0", shape=(8, 8))
+        y0 = relay.var("y0", shape=(8, 8))
+        log = relay.log(x0)
+        sub = x0 - y0
+        mul = log * sub
+        # The partitioned graph contains log, subtract, and multiply
+        func = relay.Function([x0, y0], mul)
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+        add = x + y
+        call = relay.Call(glb_0, [add, z])
+        main = relay.Function([x, y, z], call)
+        mod["main"] = main
+        return mod
+
+    def expected_different_output_region():
+        mod = tvm.IRModule()
+        x = relay.var("x", shape=(8, 8))
+        y = relay.var("y", shape=(8, 8))
+        z = relay.var("z", shape=(8, 8))
+        # The partitioned graph contains log
+        i0 = relay.var("i0", shape=(8, 8))
+        log = relay.log(i0)
+        func = relay.Function([i0], log)
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+        # The partitioned graph contains subtract
+        x0 = relay.var("x0", shape=(8, 8))
+        y0 = relay.var("y0", shape=(8, 8))
+        sub = x0 - y0
+        func = relay.Function([x0, y0], sub)
+        func = set_func_attr(func, "ccompiler", "ccompiler_1")
+        glb_1 = relay.GlobalVar("ccompiler_1")
+        mod[glb_1] = func
+        add = x + y
+        call_log = relay.Call(glb_0, [add])
+        call_sub = relay.Call(glb_1, [add, z])
+        main = relay.Function([x, y, z], call_log * call_sub)
+        mod["main"] = main
+        return mod
+
+    def get_mod():
+        x = relay.var("x", shape=(8, 8))
+        y = relay.var("y", shape=(8, 8))
+        z = relay.var("z", shape=(8, 8))
+        add = x + y
+        sub = add - z
+        log = relay.log(add)
+        sub1 = log * sub
+        f = relay.Function([x, y, z], sub1)
+        mod = tvm.IRModule()
+        mod["main"] = f
+        return mod
+
+    def test_same_output_region():
+        mod = get_mod()
+        mod = WhiteListAnnotator(["subtract", "log", "multiply"], "ccompiler")(mod)
+        mod = transform.MergeCompilerRegions()(mod)
+        mod = transform.PartitionGraph()(mod)
+        expected_mod = expected_same_output_region()
+        assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
+
+    def test_different_output_region():
+        mod = get_mod()
+        mod = WhiteListAnnotator(["subtract", "log"], "ccompiler")(mod)
+        mod = transform.MergeCompilerRegions()(mod)
+        mod = transform.PartitionGraph()(mod)
+        expected_mod = expected_different_output_region()
+        assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
+
+    test_same_output_region()
+    test_different_output_region()
 
 
 if __name__ == "__main__":
...@@ -979,11 +1043,11 @@ if __name__ == "__main__":
     test_extern_ccompiler_default_ops()
     test_extern_ccompiler()
     test_extern_dnnl()
-    # TODO(@comaniac, @zhiics): Fix constant node and re-open this case.
-    #test_extern_dnnl_mobilenet()
+    test_extern_dnnl_mobilenet()
     test_function_lifting()
     test_function_lifting_inline()
     test_constant_propagation()
     test_multiple_outputs()
     test_mixed_single_multiple_outputs()
     test_dnnl_fuse()
+    test_multiple_use_of_an_output()