Unverified Commit a2429c1f by Jon Soifer Committed by GitHub

[Relay][External Codegen] Support data types for CSourceModuleCodegen args and output (#4934)

* Support int args and no extra buffers

* Fixes

* remove testing code

* fix style

* more style

* use const args

* style

Co-authored-by: Jon Soifer <jonso@microsoft.com>
parent 87c20bb2
...@@ -41,9 +41,11 @@ class CodegenC : public ExprVisitor, public CodegenCBase { ...@@ -41,9 +41,11 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) { void VisitExpr_(const VarNode* node) {
ext_func_args_.push_back(node->name_hint()); ext_func_args_.push_back(GetRef<Var>(node));
out_.clear(); out_.clear();
out_.push_back({node->name_hint(), 0}); Output output;
output.name = node->name_hint();
out_.push_back(output);
} }
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const CallNode* call) final {
...@@ -70,6 +72,12 @@ class CodegenC : public ExprVisitor, public CodegenCBase { ...@@ -70,6 +72,12 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
for (size_t i = 0; i < in_shape.size(); ++i) { for (size_t i = 0; i < in_shape.size(); ++i) {
macro_stream << ", " << in_shape[i]; macro_stream << ", " << in_shape[i];
} }
const auto* type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
macro_stream << ", " << dtype;
macro_stream << ");"; macro_stream << ");";
func_decl_.push_back(macro_stream.str()); func_decl_.push_back(macro_stream.str());
...@@ -83,20 +91,18 @@ class CodegenC : public ExprVisitor, public CodegenCBase { ...@@ -83,20 +91,18 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
decl_stream << ", "; decl_stream << ", ";
} }
first = false; first = false;
decl_stream << out.first; decl_stream << out.name;
} }
} }
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32))
<< "Only support single output tensor with float type";
std::string out = "buf_" + std::to_string(buf_idx_++); std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type()); auto out_shape = GetShape(call->checked_type());
int out_size = 1; int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) { for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i]; out_size *= out_shape[i];
} }
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; buf_stream << dtype << "* " << out <<
" = (" << dtype << "*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str()); buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out << ");"; decl_stream << ", " << out << ");";
...@@ -104,7 +110,12 @@ class CodegenC : public ExprVisitor, public CodegenCBase { ...@@ -104,7 +110,12 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
// Update output buffer // Update output buffer
out_.clear(); out_.clear();
out_.push_back({out, out_size}); Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
} }
/*! /*!
...@@ -128,7 +139,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase { ...@@ -128,7 +139,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
/*! \brief The index of allocated buffers. */ /*! \brief The index of allocated buffers. */
int buf_idx_ = 0; int buf_idx_ = 0;
/*! \brief The arguments of a C compiler compatible function. */ /*! \brief The arguments of a C compiler compatible function. */
std::vector<std::string> ext_func_args_; Array<Var> ext_func_args_;
/*! \brief The statements of a C compiler compatible function. */ /*! \brief The statements of a C compiler compatible function. */
std::vector<std::string> ext_func_body; std::vector<std::string> ext_func_body;
/*! \brief The declaration statements of a C compiler compatible function. */ /*! \brief The declaration statements of a C compiler compatible function. */
...@@ -136,7 +147,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase { ...@@ -136,7 +147,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
/*! \brief The declaration statements of buffers. */ /*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_; std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */ /*! \brief The name and index pairs for output. */
std::vector<std::pair<std::string, int>> out_; std::vector<Output> out_;
}; };
class CSourceCodegen : public CSourceModuleCodegenBase { class CSourceCodegen : public CSourceModuleCodegenBase {
...@@ -161,15 +172,15 @@ class CSourceCodegen : public CSourceModuleCodegenBase { ...@@ -161,15 +172,15 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
// Append some common macro for operator definition. // Append some common macro for operator definition.
const char* operator_macro = R"op_macro( const char* operator_macro = R"op_macro(
#define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_) \ #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_, p_DTYPE) \
extern "C" void p_ID_(float* a, float* b, float* out) { \ extern "C" void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \
for (int64_t i = 0; i < p_DIM1_; ++i) { \ for (int64_t i = 0; i < p_DIM1_; ++i) { \
out[i] = a[i] p_OP_ b[i]; \ out[i] = a[i] p_OP_ b[i]; \
} \ } \
} }
#define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_) \ #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_, p_DTYPE) \
extern "C" void p_ID_(float* a, float* b, float* out) { \ extern "C" void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \
for (int64_t i = 0; i < p_DIM1_; ++i) { \ for (int64_t i = 0; i < p_DIM1_; ++i) { \
for (int64_t j = 0; j < p_DIM2_; ++j) { \ for (int64_t j = 0; j < p_DIM2_; ++j) { \
int64_t k = i * p_DIM2_ + j; \ int64_t k = i * p_DIM2_ + j; \
......
...@@ -35,6 +35,13 @@ namespace tvm { ...@@ -35,6 +35,13 @@ namespace tvm {
namespace relay { namespace relay {
namespace contrib { namespace contrib {
struct Output {
std::string name;
std::string dtype;
int size;
bool need_copy;
};
class CSourceModuleCodegenBase { class CSourceModuleCodegenBase {
public: public:
CSourceModuleCodegenBase() = default; CSourceModuleCodegenBase() = default;
...@@ -98,7 +105,7 @@ class CodegenCBase { ...@@ -98,7 +105,7 @@ class CodegenCBase {
* \brief Gerenate C code for the external function. * \brief Gerenate C code for the external function.
* *
* \param func_name The name of the external function. * \param func_name The name of the external function.
* \param arg_cnt The expected number of arguments. * \param args arguments to the external function.
* *
* \code * \code
* *
...@@ -116,16 +123,18 @@ class CodegenCBase { ...@@ -116,16 +123,18 @@ class CodegenCBase {
* *
* \endcode * \endcode
*/ */
void GenerateBackendCFunc(const std::string& func_name, int arg_cnt) { void GenerateBackendCFunc(const std::string& func_name,
const Array<Var>& args,
const Output& out) {
// Print signature // Print signature
code_stream_ << "\n"; code_stream_ << "\n";
code_stream_ << "extern \"C\" int " << func_name << "_wrapper_("; code_stream_ << "extern \"C\" int " << func_name << "_wrapper_(";
for (int i = 0; i < arg_cnt - 1; i++) { for (size_t i = 0; i < args.size(); i++) {
code_stream_ << "DLTensor* arg" << i << ",\n"; code_stream_ << "DLTensor* arg" << i << ",\n";
code_stream_ << "\t"; code_stream_ << "\t";
} }
if (arg_cnt > 0) { if (args.size() > 0) {
code_stream_ << "DLTensor* arg" << arg_cnt - 1 << ") {\n"; code_stream_ << "DLTensor* arg" << args.size() << ") {\n";
} }
EnterScope(); EnterScope();
...@@ -133,12 +142,13 @@ class CodegenCBase { ...@@ -133,12 +142,13 @@ class CodegenCBase {
// Generate the internal call. // Generate the internal call.
PrintIndents(); PrintIndents();
code_stream_ << func_name << "_("; code_stream_ << func_name << "_(";
for (int i = 0; i < arg_cnt - 1; i++) { for (size_t i = 0; i < args.size(); i++) {
code_stream_ << "static_cast<float*>(arg" << i << "->data),\n"; const auto& dtype_str = GetDtypeString(args[i]);
code_stream_ << "static_cast<" << dtype_str << "*>(arg" << i << "->data),\n";
PrintIndents(); PrintIndents();
} }
if (arg_cnt > 0) { if (args.size() > 0) {
code_stream_ << "static_cast<float*>(arg" << arg_cnt - 1 << "->data)"; code_stream_ << "static_cast<" << out.dtype << "*>(arg" << args.size() << "->data)";
} }
code_stream_ << ");\n"; code_stream_ << ");\n";
PrintIndents(); PrintIndents();
...@@ -207,17 +217,21 @@ class CodegenCBase { ...@@ -207,17 +217,21 @@ class CodegenCBase {
* *
* \return The emitted code string. * \return The emitted code string.
*/ */
std::string JitImpl(std::string ext_func_id, std::vector<std::string> args, std::string JitImpl(std::string ext_func_id, const Array<Var>& args,
std::vector<std::string> buf_decl, std::vector<std::string> body, const std::vector<std::string>& buf_decl,
std::vector<std::pair<std::string, int>> out) { const std::vector<std::string>& body,
const std::vector<Output>& out) {
// Create the signature. For example, it could be: // Create the signature. For example, it could be:
// extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {} // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
code_stream_ << "extern \"C\" void " << ext_func_id << "_("; code_stream_ << "extern \"C\" void " << ext_func_id << "_(";
CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support.";
for (const auto& arg : args) { for (const auto& arg : args) {
code_stream_ << "float* " << arg << ", "; const auto& dtype_str = GetDtypeString(arg);
code_stream_ << dtype_str << "* " << arg->name_hint() << ", ";
} }
code_stream_ << "float* out) {\n"; code_stream_ << out[0].dtype << "* out) {\n";
this->EnterScope(); this->EnterScope();
// Function body // Function body
...@@ -232,24 +246,60 @@ class CodegenCBase { ...@@ -232,24 +246,60 @@ class CodegenCBase {
} }
// Copy output // Copy output
CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support."; if (out[0].need_copy) {
this->PrintIndents(); this->PrintIndents();
code_stream_ << "std::memcpy(out, " << out[0].first << ", 4 * " << out[0].second << ");\n"; code_stream_ << "std::memcpy(out, " << out[0].name << ", 4 * " << out[0].size << ");\n";
// Free buffers // Free buffers
for (size_t i = 0; i < buf_decl.size(); i++) { for (size_t i = 0; i < buf_decl.size(); i++) {
this->PrintIndents(); this->PrintIndents();
code_stream_ << "std::free(buf_" << i << ");\n"; code_stream_ << "std::free(buf_" << i << ");\n";
} }
}
this->ExitScope(); this->ExitScope();
code_stream_ << "}\n"; code_stream_ << "}\n";
// Create the wrapper to call the ext_func // Create the wrapper to call the ext_func
this->GenerateBackendCFunc(ext_func_id, args.size() + 1 /* output */); this->GenerateBackendCFunc(ext_func_id, args, out[0]);
return code_stream_.str(); return code_stream_.str();
} }
/*!
* \brief Returns dtype string
*
* \param var Var to get the dtype of
*
* \return The dtype string.
*/
std::string GetDtypeString(const Var& var) {
auto ttype = var->checked_type().as<TensorTypeNode>();
CHECK(ttype) << "Expect TensorTypeNode";
return GetDtypeString(ttype);
}
/*!
* \brief Returns dtype string
*
* \param ttype TensorTypeNode* to get the dtype of
*
* \return The dtype string.
*/
std::string GetDtypeString(const TensorTypeNode* ttype) {
std::string dtype;
if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) {
dtype = "float";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
dtype = "int";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {
dtype = "int64_t";
} else {
LOG(FATAL) << "Unsupported dtype " << ttype->dtype;
}
return dtype;
}
/*! \brief The external function source code stream. */ /*! \brief The external function source code stream. */
std::ostringstream code_stream_; std::ostringstream code_stream_;
......
...@@ -45,9 +45,11 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -45,9 +45,11 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; } explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) final { void VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(node->name_hint()); ext_func_args_.push_back(GetRef<Var>(node));
out_.clear(); out_.clear();
out_.push_back({node->name_hint(), 0}); Output output;
output.name = node->name_hint();
out_.push_back(output);
} }
void VisitExpr_(const TupleGetItemNode* op) final { void VisitExpr_(const TupleGetItemNode* op) final {
...@@ -90,14 +92,14 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -90,14 +92,14 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
decl_stream << ", "; decl_stream << ", ";
} }
first = false; first = false;
decl_stream << out.first; decl_stream << out.name;
} }
} }
// Analyze the output buffer // Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>(); auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) CHECK(type_node);
<< "Only support single output tensor with float type"; const auto& dtype = GetDtypeString(type_node);
std::string out = "buf_" + std::to_string(buf_idx_++); std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type()); auto out_shape = GetShape(call->checked_type());
int out_size = 1; int out_size = 1;
...@@ -118,7 +120,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -118,7 +120,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
// Update output buffer // Update output buffer
out_.clear(); out_.clear();
out_.push_back({out, out_size}); Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
} }
std::string JIT(void) { std::string JIT(void) {
...@@ -213,13 +220,13 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -213,13 +220,13 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
*/ */
int buf_idx_{0}; int buf_idx_{0};
/*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
std::vector<std::string> ext_func_args_; Array<Var> ext_func_args_;
/*! \brief statement of the function that will be compiled using DNNL kernels. */ /*! \brief statement of the function that will be compiled using DNNL kernels. */
std::vector<std::string> ext_func_body; std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermeidate buffers. */ /*! \brief The declaration of intermeidate buffers. */
std::vector<std::string> buf_decl_; std::vector<std::string> buf_decl_;
/*! \brief The name of the the outputs. */ /*! \brief The name of the the outputs. */
std::vector<std::pair<std::string, int>> out_; std::vector<Output> out_;
}; };
/*! /*!
......
...@@ -161,6 +161,23 @@ def test_extern_gcc_single_op(): ...@@ -161,6 +161,23 @@ def test_extern_gcc_single_op():
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
def test_extern_gcc_single_op_int():
x = relay.var('x', shape=(8, 8), dtype="int32")
y = relay.var('y', shape=(8, 8), dtype="int32")
x0 = relay.var('x0', shape=(8, 8), dtype="int32")
y0 = relay.var('y0', shape=(8, 8), dtype="int32")
z = x0 + y0
f = relay.Function([x0, y0], z)
f = set_external_func_attr(f, "ccompiler", "ccompiler_0")
call = relay.Call(f, [x, y])
mod = tvm.IRModule.from_expr(call)
x_data = np.random.rand(8, 8).astype('int32')
y_data = np.random.rand(8, 8).astype('int32')
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
def test_extern_gcc(): def test_extern_gcc():
x = relay.var('x', shape=(2, 2)) x = relay.var('x', shape=(2, 2))
y = relay.var('y', shape=(2, 2)) y = relay.var('y', shape=(2, 2))
...@@ -242,5 +259,6 @@ def test_extern_dnnl(): ...@@ -242,5 +259,6 @@ def test_extern_dnnl():
if __name__ == "__main__": if __name__ == "__main__":
test_multi_node_subgraph() test_multi_node_subgraph()
test_extern_gcc_single_op() test_extern_gcc_single_op()
test_extern_gcc_single_op_int()
test_extern_gcc() test_extern_gcc()
test_extern_dnnl() test_extern_dnnl()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment