Commit c44b7bf1 by Zhi Committed by Tianqi Chen

[Relay] External codegen (#4482)

parent 603280bf
......@@ -254,6 +254,8 @@ include(cmake/modules/LLVM.cmake)
......@@ -172,6 +172,9 @@ set(USE_ROCBLAS OFF)
# Whether use contrib sort
# Whether use MKL-DNN (DNNL) codegen
# Build ANTLR parser for Relay text format
# Possible values:
# - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
file(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/
find_library(EXTERN_LIBRARY_DNNL dnnl)
file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
message(STATUS "Build with DNNL codegen: " ${EXTERN_LIBRARY_DNNL})
......@@ -170,6 +170,9 @@ TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
TVM_DLL Target stackvm(const std::vector<std::string>& options =
/*! \return A target for external device */
TVM_DLL Target ext_dev(const std::vector<std::string>& options =
} // namespace target
......@@ -268,6 +268,15 @@ class FunctionNode : public ExprNode {
bool IsPrimitive() const;
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
* \return Whether the function will be compiled using the default compiler
* (e.g. those used in the TVM stack).
bool UseDefaultCompiler() const;
TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
......@@ -588,6 +597,25 @@ std::string AsText(const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
* \brief Indicate the compiler that should be used for building this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark whether the function should be skipped during optimization. */
constexpr const char* kSkipOptimization = "SkipOptimization";
} // namespace attr
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
......@@ -133,7 +133,16 @@ class Module(ModuleBase):
files = [path_obj]
is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
has_imported_c_file = False
if self.imported_modules:
for i, m in enumerate(self.imported_modules):
if m.type_key == "c":
has_imported_c_file = True
c_file_name = "tmp_" + str(i) + ".cc"
path_cc = temp.relpath(c_file_name)
with open(path_cc, "w") as f:
path_cc = temp.relpath("")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
......@@ -143,7 +152,7 @@ class Module(ModuleBase):
fcompile = _tar.tar
fcompile = _cc.create_shared
if self.type_key == "c":
if self.type_key == "c" or has_imported_c_file:
options = []
if "options" in kwargs:
opts = kwargs["options"]
......@@ -309,6 +309,10 @@ Target intel_graphics(const std::vector<std::string>& options) {
Target stackvm(const std::vector<std::string>& options) {
return CreateTarget("stackvm", options);
// Build the target for an external device by delegating to CreateTarget with the
// "ext_dev" target name and the caller-supplied options.
Target ext_dev(const std::vector<std::string>& options) {
return CreateTarget("ext_dev", options);
} // namespace target
bool LLVMEnabled() {
......@@ -69,6 +69,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "Only support simply one-level hierarchy";
std::string tkey = im->type_key();
if (tkey == "c") continue;
// translate to C program
......@@ -73,6 +73,10 @@ struct GraphCodegen {
return CallFunc<std::string>("get_graph_json", nullptr);
// Invoke the "get_external_modules" function through CallFunc and return its result
// as an array of runtime modules.
Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
......@@ -148,6 +152,10 @@ class RelayBuildModule : public runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetLoweredFunc();
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetExternalModules();
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2);
......@@ -474,6 +482,20 @@ class RelayBuildModule : public runtime::ModuleNode {
Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
if (!ext_mods.empty()) {
CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1)
<< "Expect to have a TVM DSOModule when multiple external runtime modules exist";
if (lowered_funcs.size() == 0) {
// Execute the whole module using external runtime.
ret_.mod = ext_mods[0];
} else {
// Import all external runtime modules.
for (const auto& it : ext_mods) {
......@@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
......@@ -608,6 +609,46 @@ class CompileEngineImpl : public CompileEngineNode {
return LowerShapeFuncInternal(key)->cached_func;
// Lower every cached function that opts out of the default compiler.
// Such functions are grouped into one Relay module per external compiler name, and
// each module is passed to the codegen registered as "relay.ext.<compiler name>".
// NOTE(review): the statements that populate `ret` and consume `cached_ext_funcs`
// are not visible in this excerpt — confirm against the full source.
Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, relay::Module> ext_mods;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
// Only functions annotated with a non-default compiler are handled here.
if (!src_func->UseDefaultCompiler()) {
auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
const tvm::ir::StringImm* code_gen =<tvm::ir::StringImm>();
CHECK(code_gen) << "No external codegen is set";
// Create an empty Relay module the first time a compiler name is encountered.
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImm* symbol_name =<tvm::ir::StringImm>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
// The external symbol is used as the global name of the function in that module.
auto gv = GlobalVarNode::make(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
Array<tvm::runtime::Module> ret;
for (const auto& it : ext_mods) {
// Look up the external codegen in the global registry by its "relay.ext." name.
std::string ext_name = "relay.ext." + it.first;
auto pf = tvm::runtime::Registry::Get(ext_name);
CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
// The codegen receives the whole per-compiler Relay module and returns a runtime module.
runtime::Module ext_mod = (*pf)(it.second);
CHECK(ext_mod.defined()) << "No external runtime is generated.";
// No need to cache external functions as we collected them all to create
// external runtime modules.
for (const auto& it : cached_ext_funcs) {
return ret;
void Clear() final {
......@@ -648,6 +689,18 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (!key->source_func->UseDefaultCompiler()) {
auto cache_node = make_node<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
return value;
// Enforce use the target.
With<Target> target_scope(key->target);
......@@ -759,42 +812,46 @@ const CompileEngine& CompileEngine::Global() {
return *inst;
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make);
.set_body_typed<CompileEngine()>([]() {
return CompileEngine::Global();
return CompileEngine::Global();
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->Lower(key);
return self->Lower(key);
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
return self->LowerShapeFunc(key);
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
return self->LowerExternalFunctions();
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->JIT(key);
return self->JIT(key);
[](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
} // namespace relay
} // namespace tvm
......@@ -26,6 +26,7 @@
#include <tvm/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
......@@ -186,6 +187,12 @@ class CompileEngineNode : public Node {
* \return The result.
virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
* \brief Lower the external function using external codegen tools.
* \return The runtime modules for each needed external codegen tool.
virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;
/*! \brief clear the cache. */
virtual void Clear() = 0;
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/object.h>
#include <fstream>
#include <sstream>
#include "codegen_c.h"
namespace tvm {
namespace relay {
namespace contrib {
* \brief An example codegen that is only used for quick prototyping and testing
* purposes. Only a few binary operators (add, subtract, multiply) are covered. Users
* may need to extend it to cover more operators.
class CodegenC : public ExprVisitor, public CodegenCBase {
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) {
out_.push_back({node->name_hint(), 0});
void VisitExpr_(const CallNode* call) final {
std::ostringstream macro_stream;
std::ostringstream decl_stream;
std::ostringstream buf_stream;
std::string func_name = ext_func_id_ + "_" + std::to_string(func_idx++);
// Make function declaration
macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", ";
if (IsOp(call, "add")) {
macro_stream << "+";
} else if (IsOp(call, "subtract")) {
macro_stream << "-";
} else if (IsOp(call, "multiply")) {
macro_stream << "*";
} else {
LOG(FATAL) << "Unrecognized op";
auto in_shape = GetShape(call->args[0]->checked_type());
for (size_t i = 0; i < in_shape.size(); ++i) {
macro_stream << ", " << in_shape[i];
macro_stream << ");";
// Make function call when visiting arguments
bool first = true;
decl_stream << func_name << "(";
for (size_t i = 0; i < call->args.size(); ++i) {
for (auto out : out_) {
if (!first) {
decl_stream << ", ";
first = false;
decl_stream << out.first;
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32))
<< "Only support single output tensor with float type";
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
decl_stream << ", " << out << ");";
// Update output buffer
out_.push_back({out, out_size});
* \brief Emit the source code that invokes C compiler compatible wrappers.
* \return The emitted code.
std::string JIT() {
// Write function macros
for (auto decl : func_decl_) {
code_stream_ << decl << "\n";
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
/*! \brief The function id that represents a C source function. */
std::string ext_func_id_ = "";
/*! \brief The index of a wrapped C function. */
int func_idx = 0;
/*! \brief The index of allocated buffers. */
int buf_idx_ = 0;
/*! \brief The arguments of a C compiler compatible function. */
std::vector<std::string> ext_func_args_;
/*! \brief The statements of a C compiler compatible function. */
std::vector<std::string> ext_func_body;
/*! \brief The declaration statements of a C compiler compatible function. */
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */
std::vector<std::pair<std::string, int>> out_;
class CSourceCodegen : public CSourceModuleCodegenBase {
// Emit the C source for a single Relay function into code_stream_.
// The function must be defined; its external symbol is used as the id of the
// CodegenC builder whose JIT() output is appended to the stream.
void GenCFunc(const Function& func) {
CHECK(func.defined()) << "Input error: expect a Relay function.";
// Record the external symbol for runtime lookup.
auto sid = GetExtSymbol(func);
auto builder = CodegenC(sid);
code_stream_ << builder.JIT();
runtime::Module CreateCSourceModule(const NodeRef& ref) override {
// Create headers
code_stream_ << "#include <cstdint>\n";
code_stream_ << "#include <iostream>\n";
code_stream_ << "#include <cstdlib>\n";
code_stream_ << "#include <stdio.h>\n";
code_stream_ << "#include <cstring>\n";
code_stream_ << "#include <tvm/runtime/c_runtime_api.h>\n";
code_stream_ << "#include <dlpack/dlpack.h>\n";
// Append some common macro for operator definition.
const char* operator_macro = R"op_macro(
#define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_) \
extern "C" void p_ID_(float* a, float* b, float* out) { \
for (int64_t i = 0; i < p_DIM1_; ++i) { \
out[i] = a[i] p_OP_ b[i]; \
} \
#define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_) \
extern "C" void p_ID_(float* a, float* b, float* out) { \
for (int64_t i = 0; i < p_DIM1_; ++i) { \
for (int64_t j = 0; j < p_DIM2_; ++j) { \
int64_t k = i * p_DIM2_ + j; \
out[k] = a[k] p_OP_ b[k]; \
} \
} \
code_stream_ << operator_macro << "\n\n";
if (ref->IsInstance<FunctionNode>()) {
} else if (ref->IsInstance<relay::ModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref);
for (const auto& it : mod->functions) {
} else {
LOG(FATAL) << "The input ref is expected to be a Relay function or module"
<< "\n";
// Create a CSourceModule
const auto* pf = runtime::Registry::Get("module.csource_module_create");
CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code_stream_.str(), "cc");
std::ostringstream code_stream_;
* \brief The external compiler/codegen tool. It takes a Relay expression/module and
* compile it into a runtime module.
* The external codegen tool should have been registered similarly to LLVM,
* CUDA, etc, under TVM, so the generated code could be packed in a runtime
* module. This module simplifies code serialization and invocation.
runtime::Module CCompiler(const NodeRef& ref) {
CSourceCodegen csource;
return csource.CreateCSourceModule(ref);
} // namespace contrib
} // namespace relay
} // namespace tvm
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* \file src/relay/backend/contrib/codegen_c/codegen_c.h
* \brief The base class for external codegen tools.
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace tvm {
namespace relay {
namespace contrib {
class CSourceModuleCodegenBase {
CSourceModuleCodegenBase() = default;
* \brief Create a runtime module for the external library. For example, it
* could be a CSourceModule that can be directly compiled and linked together
* with a DSOModule, or a json style module that emits a json artifact that
* is able to be executed by a customized json runtime.
* \param ref The ext_func Relay expression/module to be executed using extern ops.
* \return A runtime module.
virtual runtime::Module CreateCSourceModule(const NodeRef& ref) = 0;
* \brief Get the external symbol of the Relay function name.
* \param func The provided function.
* \return An external symbol.
std::string GetExtSymbol(const Function& func) const {
const auto name_node = FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value;
return ext_symbol;
// The base class to generate the declaration functions in C.
class CodegenCBase {
/*! \brief Print indents using spaces. */
void PrintIndents() {
for (int i = 0; i < indent_; i++) {
code_stream_ << ' ';
* \brief Enter a new scope.
void EnterScope() { indent_ += 2; }
* \brief Exit a scope.
void ExitScope() {
CHECK_GE(indent_, 2U) << "Wrong ident found.";
indent_ -= 2;
* \brief Generate C code for the external function.
* \param func_name The name of the external function.
* \param arg_cnt The expected number of arguments.
* \code
* // An example code for the generated C function.
* extern "C" int foo(TVMValue* value, int* type_code, int nargs) {
* if (nargs != 3) {
* printf("foo expects 3 args, but received %d\n", nargs);
* return 1;
* }
* DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
* DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
* DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
* foo_(static_cast<float*>(arg0->data),
* static_cast<float*>(arg1->data),
* static_cast<float*>(out->data));
* return 0;
* }
* \endcode
void GenerateBackendCFunc(const std::string& func_name, int arg_cnt) {
// Print signature
code_stream_ << "\n";
code_stream_ << "extern \"C\" int " << func_name;
code_stream_ << "(TVMValue* value, int* type_code, int nargs) {\n";
// Print guard
code_stream_ << "if (nargs != " << arg_cnt << "){\n";
code_stream_ << "printf(\"" << func_name << " expects " << arg_cnt
<< " arguments, but received %d\\n\", nargs);\n";
code_stream_ << "return 1;\n";
code_stream_ << "}\n";
// According to TVM's calling convention, the last one is output.
for (int i = 0; i < arg_cnt; i++) {
code_stream_ << "DLTensor* arg" << i << " = "
<< "static_cast<DLTensor*>(value[" << i << "].v_handle);\n";
// Generate the call.
code_stream_ << func_name << "_(";
for (int i = 0; i < arg_cnt - 1; i++) {
code_stream_ << "static_cast<float*>(arg" << i << "->data), ";
if (arg_cnt > 0) {
code_stream_ << "static_cast<float*>(arg" << arg_cnt - 1 << "->data)";
code_stream_ << ");\n\n";
code_stream_ << "return 0;\n";
code_stream_ << "}";
* \brief Emit the code for external runtime.
* \return The code string.
virtual std::string JIT() = 0;
* \brief Extract the shape from a Relay tensor type.
* \param type The provided type.
* \return The extracted shape in a list.
std::vector<int> GetShape(const Type& type) const {
const auto* ttype =<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
auto* val = ttype->shape[i].as<IntImm>();
return shape;
* \brief Check if a call has the provided name.
* \param call A Relay call node.
* \param op_name The name of the expected call.
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
bool IsOp(const CallNode* call, std::string op_name) const {
const auto* op_node = call-><OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
* \brief A common interface that is used by various external runtimes to
* generate the wrapper to invoke external kernels.
* \param ext_func_id The unique id of an external function. It will be used
* during runtime to pick the correct external function.
* \param args The arguments used by the external function.
* \param buf_decl The declarations of temporary buffers that are used to store the
* intermediate results of each external kernel.
* \param body The statements of the external function.
* \param out The name and id pairs for output.
* \return The emitted code string.
std::string JitImpl(std::string ext_func_id, std::vector<std::string> args,
std::vector<std::string> buf_decl, std::vector<std::string> body,
std::vector<std::pair<std::string, int>> out) {
// Create the signature. For example, it could be:
// extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
code_stream_ << "extern \"C\" void " << ext_func_id << "_(";
for (const auto& arg : args) {
code_stream_ << "float* " << arg << ", ";
code_stream_ << "float* out) {\n";
// Function body
for (auto decl : buf_decl) {
code_stream_ << decl << "\n";
code_stream_ << "\n";
for (auto stmt : body) {
code_stream_ << stmt << "\n";
// Copy output
CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support.";
code_stream_ << "std::memcpy(out, " << out[0].first << ", 4 * " << out[0].second << ");\n";
// Free buffers
for (size_t i = 0; i < buf_decl.size(); i++) {
code_stream_ << "std::free(buf_" << i << ");\n";
code_stream_ << "}\n";
// Create the wrapper to call the ext_func
this->GenerateBackendCFunc(ext_func_id, args.size() + 1 /* output */);
return code_stream_.str();
/*! \brief The external function source code stream. */
std::ostringstream code_stream_;
/*! \brief Indent of the source code. */
int indent_{0};
} // namespace contrib
} // namespace relay
} // namespace tvm
......@@ -24,6 +24,7 @@
#include <dmlc/any.h>
#include <dmlc/json.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
......@@ -55,6 +56,7 @@ using TargetsMap = std::unordered_map<int, Target>;
struct LoweredOutput {
std::string graph_json;
Map<std::string, Array<LoweredFunc> > lowered_funcs;
Array<tvm::runtime::Module> external_mods;
std::unordered_map<std::string, tvm::runtime::NDArray> params;
......@@ -226,6 +228,7 @@ class GraphRuntimeCodegen
ret.lowered_funcs.Set(kv.first, tmp);
ret.external_mods = compile_engine_->LowerExternalFunctions();
return ret;
......@@ -380,6 +383,25 @@ class GraphRuntimeCodegen
return fields;
// Add a graph node for a call expression: every argument of `op` is visited first,
// then an op node is created through GraphOpNode::make_node_ptr and registered with
// AddNode under the expression that `op` refers to.
// NOTE(review): the remaining make_node_ptr arguments (and how `func_name` and
// `inputs` are used) are not visible in this excerpt — confirm against the full source.
std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op,
const std::string& op_name,
const std::string& func_name) {
std::vector<GraphNodeRef> inputs;
for (auto arg : op->args) {
auto res = VisitExpr(arg);
for (auto nr : res) {
auto node = GraphOpNode::make_node_ptr(op_name,
return AddNode(node, GetRef<Expr>(op));
std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
Expr expr = GetRef<Expr>(op);
Function func;
......@@ -398,17 +420,26 @@ class GraphRuntimeCodegen
<< "(i.e functions composed of fusable operator invocations)";
CHECK_GE(storage_device_map_.count(expr), 0);
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
if (!func->UseDefaultCompiler()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
CHECK(ext_func.defined()) << "External function is not defined.";
return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
CHECK_GE(storage_device_map_.count(expr), 0);
auto &device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value;
Target target;
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
for (auto kv : targets_) {
target = kv.second;
const auto& it = targets_.begin();
target = (*it).second;
} else {
// heterogeneous execution.
std::string call_dev_name;
......@@ -424,28 +455,17 @@ class GraphRuntimeCodegen
target = targets_[call_dev_type];
CCacheKey key = (*pf0)(func, target);
CachedFunc lowerd_func = (*pf1)(compile_engine_, key);
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = {};
for (auto f : lowerd_func->funcs) {
for (auto f : lowered_func->funcs) {
std::vector<GraphNodeRef> inputs;
for (auto arg : op->args) {
auto res = VisitExpr(arg);
for (auto nr : res) {
auto& op_name = lowerd_func->func_name;
auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name),
return AddNode(node, expr);
return GraphAddCallNode(op,
std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
......@@ -470,7 +490,7 @@ class GraphRuntimeCodegen
return {};
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
throw std::invalid_argument("function not supported");
CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen";
return {};
std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
......@@ -628,7 +648,6 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
*rv = ret;
} else if (name == "get_param_by_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string key = args[0];
......@@ -639,6 +658,10 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->output_.lowered_funcs;
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->output_.external_mods;
} else {
return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
......@@ -37,21 +37,19 @@ namespace tvm {
namespace relay {
namespace vm {
static const char* kIsClosure = "IsClosure";
// Derive a name for a function: the literal prefix "lifted_name" followed by the
// decimal value of the function's structural hash.
inline std::string GenerateName(const Function& func) {
size_t hash = StructuralHash()(func);
return std::string("lifted_name") + std::to_string(hash);
bool IsClosure(const Function& func) {
NodeRef res = FunctionGetAttr(func, kIsClosure);
NodeRef res = FunctionGetAttr(func, attr::kClosure);
const ir::IntImm* pval =<ir::IntImm>();
return pval && pval->value != 0;
Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
return FunctionSetAttr(func, attr::kClosure, tvm::Integer(1));
/* The goal of this class is to lift out any nested functions into top-level
......@@ -157,13 +157,13 @@ FuncType FunctionNode::func_type_annotation() const {
bool FunctionNode::IsPrimitive() const {
NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive");
NodeRef res = FunctionGetAttr(GetRef<Function>(this), attr::kPrimitive);
const ir::IntImm* pval =<ir::IntImm>();
return pval && pval->value != 0;
Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
return FunctionSetAttr(GetRef<Function>(this), attr::kParams, parameters);
......@@ -173,7 +173,7 @@ TVM_REGISTER_API("relay._expr.FunctionSetParams")
tvm::Map<Var, Constant> FunctionNode::GetParams() const {
auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
auto node_ref = FunctionGetAttr(GetRef<Function>(this), attr::kParams);
return Downcast<tvm::Map<Var, Constant>>(node_ref);
......@@ -182,6 +182,12 @@ TVM_REGISTER_API("relay._expr.FunctionGetParams")
return func->GetParams();
// Report whether this function is built with the default compiler.
// Returns true when no string-valued attr::kCompiler attribute is obtained (pval is
// null) or when that attribute's value is exactly "default".
bool FunctionNode::UseDefaultCompiler() const {
NodeRef res = FunctionGetAttr(GetRef<Function>(this), attr::kCompiler);
const ir::StringImm* pval =<ir::StringImm>();
return pval == nullptr || pval->value == "default";
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); }
......@@ -239,7 +239,8 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque;
if (const OpNode* opnode = call-><OpNode>()) {
const OpNode* opnode = call-><OpNode>();
if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
} else {
this->Update(call->op, node, kOpaque);
......@@ -932,7 +933,7 @@ class FuseMutator : private ExprMutator {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.has_call));
func = FunctionSetAttr(func, attr::kPrimitive, tvm::Integer(visitor.has_call));
return CallNode::make(func, ginfo.arguments, Attrs());
......@@ -329,12 +329,10 @@ Module FunctionPassNode::operator()(const Module& mod,
return updated_mod;
// TODO(zhiics) Create an enum attribute for FunctionNode
// enum Attribute {kPrimitive, kSkipOptimization}
bool FunctionPassNode::SkipFunction(const Function& func) const {
NodeRef res = FunctionGetAttr(func, "SkipOptimization");
const ir::IntImm* pval =<ir::IntImm>();
return pval && pval->value != 0;
NodeRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
const ir::IntImm* pval =<ir::IntImm>();
return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* \file src/runtime/contrib/dnnl/
* \brief TVM compatible wrappers for dnnl kernels.
#include "dnnl_kernel.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
namespace tvm {
namespace runtime {
namespace contrib {
using namespace dnnl;
/*!
 * \brief Plain aggregate with a single member: a pointer to `void*`.
 *
 * NOTE(review): presumably carries packed kernel arguments; no use of this
 * type is visible in this part of the file, so confirm against callers.
 */
struct DnnlPackedArgs {
  void** data;
};
/*!
 * \brief Copy the raw bytes of a DNNL memory object into a caller buffer.
 *
 * The number of bytes copied is the size reported by the memory descriptor.
 *
 * \param handle Destination buffer; must be able to hold
 *               mem.get_desc().get_size() bytes.
 * \param mem    DNNL memory object whose data handle is read.
 */
inline void read_from_dnnl_memory(void* handle, const memory& mem) {
  const size_t bytes = mem.get_desc().get_size();
  const uint8_t* src = static_cast<const uint8_t*>(mem.get_data_handle());
  std::copy(src, src + bytes, static_cast<uint8_t*>(handle));
}
extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
stream s(eng);
memory::dims conv2d_src_tz = {p_N_, p_C_, p_H_, p_W_};
memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_};
if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_};
memory::dims conv2d_bias_tz = {p_O_};
memory::dims conv2d_dst_tz = {p_N_, p_O_,
(p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_};
memory::dims conv2d_strides = {p_Sh_, p_Sw_};
memory::dims conv2d_padding = {p_Ph_, p_Pw_};
std::vector<float> conv2d_bias(p_O_, 0);
auto user_src_memory =
memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng,
auto conv2d_user_bias_memory =
memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng,;
auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any);
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);
auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct,
conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md,
conv2d_strides, conv2d_padding, conv2d_padding);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng);
auto conv2d_src_memory = user_src_memory;
auto conv2d_weights_memory = user_weights_memory;
auto conv2d_dst_memory = memory(conv2d_prim_desc.dst_desc(), eng);
auto conv = convolution_forward(conv2d_prim_desc);
conv.execute(s, {{DNNL_ARG_SRC, conv2d_src_memory},
{DNNL_ARG_WEIGHTS, conv2d_weights_memory},
{DNNL_ARG_BIAS, conv2d_user_bias_memory},
{DNNL_ARG_DST, conv2d_dst_memory}});
read_from_dnnl_memory(out, conv2d_dst_memory);
extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
int p_I_, int p_O_) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
stream s(eng);
memory::dims data_tz = {p_B_, p_I_};
memory::dims weight_tz = {p_O_, p_I_};
memory::dims bias_tz = {p_O_};
memory::dims dst_tz = {p_B_, p_O_};
auto data_md = memory::desc{{data_tz}, dt::f32, tag::nc};
auto weight_md = memory::desc({{weight_tz}, dt::f32, tag::nc});
auto bias_md = memory::desc({{bias_tz}, dt::f32, tag::x});
auto dst_md = memory::desc({{dst_tz}, dt::f32, tag::nc});
std::vector<float> bias(p_O_, 0);
auto data_memory = memory(data_md, eng, data);
auto weight_memory = memory(weight_md, eng, weight);
auto bias_memory = memory(bias_md, eng,;
auto dst_memory = memory(dst_md, eng);
auto dense_desc = inner_product_forward::desc(
prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md);
auto dense_prim_desc = inner_product_forward::primitive_desc(dense_desc, eng);
assert(dst_md == dense_prim_desc.dst_desc());
auto dense = inner_product_forward(dense_prim_desc);
dense.execute(s, {{DNNL_ARG_SRC, data_memory},
{DNNL_ARG_WEIGHTS, weight_memory},
{DNNL_ARG_BIAS, bias_memory},
{DNNL_ARG_DST, dst_memory}});
read_from_dnnl_memory(out, dst_memory);
extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
int p_W_) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
stream s(eng);
memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_};
auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw};
auto data_memory = memory(data_md, eng, data);
auto dst_memory = memory(data_md, eng);
auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference,
algorithm::eltwise_relu, data_md, 0);
auto relu_prim_desc = eltwise_forward::primitive_desc(relu_desc, eng);
assert(data_md == relu_prim_desc.dst_desc());
auto relu = eltwise_forward(relu_prim_desc);
relu.execute(s, {{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, dst_memory}});
read_from_dnnl_memory(out, dst_memory);
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
stream s(eng);
memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_};
auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw};
auto data_memory = memory(data_md, eng, data);
auto dst_memory = memory(data_md, eng);
auto bn_desc = batch_normalization_forward::desc(
prop_kind::forward_inference, data_md, p_E_,
normalization_flags::use_global_stats |
auto bn_prim_desc = batch_normalization_forward::primitive_desc(bn_desc, eng);
assert(data_md == bn_prim_desc.dst_desc());
float* weight = reinterpret_cast<float*>(malloc(sizeof(float) * 2 * p_C_));
memcpy(weight, gamma, sizeof(float) * p_C_);
memcpy(weight + p_C_, beta, sizeof(float) * p_C_);
auto weight_memory = memory(bn_prim_desc.weights_desc(), eng, weight);
auto mean_memory = memory(bn_prim_desc.mean_desc(), eng, mean);
auto variance_memory = memory(bn_prim_desc.variance_desc(), eng, variance);
auto bn = batch_normalization_forward(bn_prim_desc);
bn.execute(s, {{DNNL_ARG_SRC, data_memory},
{DNNL_ARG_DST, dst_memory},
{DNNL_ARG_SCALE_SHIFT, weight_memory},
{DNNL_ARG_MEAN, mean_memory},
{DNNL_ARG_VARIANCE, variance_memory}});
read_from_dnnl_memory(out, dst_memory);
extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_,
int p_C_, int p_H_, int p_W_) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
stream s(eng);
memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_};
auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw};
auto weight_md = memory::desc({{data_tz}, dt::f32, tag::nchw});
auto dst_md = memory::desc({{data_tz}, dt::f32, tag::nchw});
auto data_memory = memory(data_md, eng, data);
auto weight_memory = memory(weight_md, eng, weight);
auto dst_memory = memory(dst_md, eng);
auto add_desc =
binary::desc(algorithm::binary_add, data_md, weight_md, dst_md);
auto add_prim_desc = binary::primitive_desc(add_desc, eng);
assert(dst_md == add_prim_desc.dst_desc());
auto add = binary(add_prim_desc);
add.execute(s, {{DNNL_ARG_SRC_0, data_memory},
{DNNL_ARG_SRC_1, weight_memory},
{DNNL_ARG_DST, dst_memory}});
read_from_dnnl_memory(out, dst_memory);
} // namespace contrib
} // namespace runtime
} // namespace tvm
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* \file src/runtime/contrib/dnnl/dnnl_kernel.h
* \brief Use external dnnl library kernels.
#include <tvm/runtime/c_runtime_api.h>
#include "dnnl.hpp"
namespace tvm {
namespace runtime {
namespace contrib {
using namespace dnnl;
/*!
 * \brief float32 NCHW 2-D convolution.
 * Dims: batch p_N_, input channels p_C_, height p_H_, width p_W_, output
 * channels p_O_, groups p_G_, padding p_Ph_/p_Pw_, kernel p_Kh_/p_Kw_,
 * strides p_Sh_/p_Sw_.
 */
extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_,
                                    int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
                                    int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_);

/*! \brief float32 dense layer: batch p_B_, input features p_I_, output features p_O_. */
extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
                                   int p_O_);

/*! \brief float32 ReLU over an NCHW tensor of dims {p_N_, p_C_, p_H_, p_W_}. */
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);

/*!
 * \brief float32 batch normalization over an NCHW tensor of dims
 * {p_N_, p_C_, p_H_, p_W_}; p_E_ is the epsilon argument.
 */
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
                                float* variance, float* out, int p_N_, int p_C_, int p_H_, int p_W_,
                                int p_E_);

/*! \brief float32 element-wise add of two NCHW tensors of dims {p_N_, p_C_, p_H_, p_W_}. */
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_,
                                 int p_H_, int p_W_);
} // namespace contrib
} // namespace runtime
} // namespace tvm
......@@ -126,6 +126,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
if (tkey == "c") continue;
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for graph partitioning."""
import os
import sys
import numpy as np
import pytest
import tvm
import tvm.relay.testing
import tvm.relay.transform
from tvm import relay
from tvm.contrib import util
def check_result(mod, map_inputs, out_shape, result, tol=1e-5):
    """Build `mod` for llvm, run it in the graph runtime and compare output 0.

    Parameters
    ----------
    mod : relay module to build.
    map_inputs : dict mapping input names to the arrays passed to `set_input`.
    out_shape : shape of the buffer allocated for output 0.
    result : expected value of output 0.
    tol : used as both rtol and atol in the comparison.

    On Windows the check is skipped after printing a message.
    """
    if sys.platform == "win32":
        print("Skip test on Windows for now")
        return

    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
        json, lib, _ = relay.build(mod, "llvm")

    # Locate src/runtime/contrib relative to this test file so it can be put
    # on the include path when the library is exported.
    test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
    source_dir = os.path.join(test_dir, "..", "..", "..")
    contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")

    kwargs = {}
    kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
    tmp_path = util.tempdir()
    lib_name = 'lib.so'
    lib_path = tmp_path.relpath(lib_name)
    lib.export_library(lib_path, fcompile=False, **kwargs)
    lib = tvm.module.load(lib_path)

    ctx = tvm.cpu()
    rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
    for name, data in map_inputs.items():
        rt_mod.set_input(name, data)
    # Execute the graph before fetching the output.
    rt_mod.run()
    out = tvm.nd.empty(out_shape, ctx=ctx)
    out = rt_mod.get_output(0, out)

    tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def set_external_func_attr(func, compiler, ext_symbol):
    """Return `func` with the attributes that mark it for an external codegen.

    Sets, in order, "Primitive" to the int32 immediate 1, "Compiler" to
    `compiler` and "ExternalSymbol" to `ext_symbol`. Each `set_attribute`
    call's return value is fed into the next call.
    """
    attributes = (
        ("Primitive", tvm.expr.IntImm("int32", 1)),
        ("Compiler", tvm.expr.StringImm(compiler)),
        ("ExternalSymbol", tvm.expr.StringImm(ext_symbol)),
    )
    for key, value in attributes:
        func = func.set_attribute(key, value)
    return func
def test_multi_node_subgraph():
    """Two external "ccompiler" subgraphs plus a TVM-side branch, concatenated.

    Each subgraph computes ((x + a) - b) * c on (10, 10) inputs; the remaining
    branch computes (x + w6) - w7. The three results are concatenated on
    axis 0 and compared against the same computation done with numpy.
    """
    x = relay.var('x', shape=(10, 10))
    w0 = relay.var('w0', shape=(10, 10))
    w1 = relay.var('w1', shape=(10, 10))
    w2 = relay.var('w2', shape=(10, 10))
    w3 = relay.var('w3', shape=(10, 10))
    w4 = relay.var('w4', shape=(10, 10))
    w5 = relay.var('w5', shape=(10, 10))
    w6 = relay.var('w6', shape=(10, 10))
    w7 = relay.var('w7', shape=(10, 10))

    # subgraph0
    x0 = relay.var('x0', shape=(10, 10))
    w00 = relay.var('w00', shape=(10, 10))
    w01 = relay.var('w01', shape=(10, 10))
    w02 = relay.var('w02', shape=(10, 10))
    z00 = relay.add(x0, w00)
    p00 = relay.subtract(z00, w01)
    q00 = relay.multiply(p00, w02)
    subgraph0 = relay.Function([x0, w00, w01, w02], q00)
    subgraph0 = set_external_func_attr(subgraph0, "ccompiler", "ccompiler_0")
    call0 = relay.Call(subgraph0, [x, w0, w1, w2])

    # subgraph1
    x1 = relay.var('x1', shape=(10, 10))
    w10 = relay.var('w10', shape=(10, 10))
    w11 = relay.var('w11', shape=(10, 10))
    w12 = relay.var('w12', shape=(10, 10))
    z10 = relay.add(x1, w10)
    p10 = relay.subtract(z10, w11)
    q10 = relay.multiply(p10, w12)
    subgraph1 = relay.Function([x1, w10, w11, w12], q10)
    subgraph1 = set_external_func_attr(subgraph1, "ccompiler", "ccompiler_1")
    call1 = relay.Call(subgraph1, [x, w3, w4, w5])

    # Other parts on TVM
    z2 = relay.add(x, w6)
    q2 = relay.subtract(z2, w7)

    r = relay.concatenate((call0, call1, q2), axis=0)
    f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r)
    mod = relay.Module()
    mod["main"] = f
    mod = relay.transform.InferType()(mod)

    x_data = np.random.rand(10, 10).astype('float32')
    w_data = []
    for _ in range(8):
        w_data.append(np.random.rand(10, 10).astype('float32'))

    map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
    map_inputs["x"] = x_data
    check_result(
        mod, map_inputs, (30, 10),
        np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
                        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
                        x_data + w_data[6] - w_data[7]),
                       axis=0))
def test_extern_gcc_single_op():
    """A single (8, 8) addition offloaded to the "ccompiler" external codegen."""
    x = relay.var('x', shape=(8, 8))
    y = relay.var('y', shape=(8, 8))

    # The external function has its own parameters; x and y are bound at the
    # call site.
    x0 = relay.var('x0', shape=(8, 8))
    y0 = relay.var('y0', shape=(8, 8))
    z = x0 + y0
    f = relay.Function([x0, y0], z)
    f = set_external_func_attr(f, "ccompiler", "ccompiler_0")
    call = relay.Call(f, [x, y])
    mod = relay.Module.from_expr(call)

    x_data = np.random.rand(8, 8).astype('float32')
    y_data = np.random.rand(8, 8).astype('float32')

    check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
def test_extern_gcc():
    """Three chained "ccompiler" functions: (y * y) - (x + x) on (2, 2) inputs."""
    x = relay.var('x', shape=(2, 2))
    y = relay.var('y', shape=(2, 2))

    # subgraph for mul
    x0 = relay.var('x0', shape=(2, 2))
    y0 = relay.var('y0', shape=(2, 2))
    mul = x0 * y0
    mul = relay.Function([x0, y0], mul)
    mul = set_external_func_attr(mul, "ccompiler", "ccompiler_2")
    call_mul = relay.Call(mul, [y, y])

    # subgraph for add
    x1 = relay.var('x1', shape=(2, 2))
    y1 = relay.var('y1', shape=(2, 2))
    add = x1 + y1
    add = relay.Function([x1, y1], add)
    add = set_external_func_attr(add, "ccompiler", "ccompiler_1")
    call_add = relay.Call(add, [x, x])

    # subgraph for sub; consumes the results of the two calls above
    x2 = relay.var('x2', shape=(2, 2))
    y2 = relay.var('y2', shape=(2, 2))
    sub = x2 - y2
    sub = relay.Function([x2, y2], sub)
    sub = set_external_func_attr(sub, "ccompiler", "ccompiler_0")
    call_sub = relay.Call(sub, [call_mul, call_add])
    mod = relay.Module.from_expr(call_sub)

    x_data = np.random.rand(2, 2).astype('float32')
    y_data = np.random.rand(2, 2).astype('float32')

    check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
def test_extern_dnnl():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
dtype = 'float32'
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
data0 = relay.var('data0', shape=(ishape), dtype=dtype)
weight0 = relay.var('weight0', shape=(w1shape), dtype=dtype)
data1 = relay.var('data0', shape=(ishape), dtype=dtype)
weight1 = relay.var('weight0', shape=(w1shape), dtype=dtype)
weight2 = relay.var('weight1', shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data1,
kernel_size=(3, 3),
padding=(1, 1),
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
kernel_size=(3, 3),
padding=(1, 1),
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
f = relay.Function([data1, weight1, weight2], out)
ref_mod = relay.Module()
ref_mod['main'] = f
f = set_external_func_attr(f, "dnnl", "dnnl_0")
call = relay.Call(f, [data0, weight0, weight0])
mod = relay.Module.from_expr(call)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w_data = np.random.uniform(0, 1, w1shape).astype(dtype)
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
ref_res = ref_ex.evaluate()(i_data, w_data, w_data)
check_result(mod, {"data0": i_data, "weight0": w_data},
(1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
# Run every test in this file when it is executed as a script.
if __name__ == "__main__":
    test_multi_node_subgraph()
    test_extern_gcc_single_op()
    test_extern_gcc()
    test_extern_dnnl()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment