Unverified Commit 5088a034 by Zhi Committed by GitHub

[Relay][BYOCG] Propagate constant to subgraphs (#5094)

* bind constant to subgraphs

* con -> constant
parent 53643bdb
......@@ -19,6 +19,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/object.h>
......@@ -40,7 +41,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
public:
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
void VisitExpr_(const VarNode* node) {
void VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(GetRef<Var>(node));
out_.clear();
Output output;
......@@ -48,6 +49,55 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
out_.push_back(output);
}
void VisitExpr_(const ConstantNode* cn) final {
Constant constant = GetRef<Constant>(cn);
if (visited_.count(constant)) {
// Note this is for demostration purpose. ConstantNode doesn't necessarily
// belong to calls. We need to revisit this when tuples come into play.
out_.push_back(visited_[constant]);
return;
}
std::ostringstream decl_stream;
std::ostringstream buf_stream;
out_.clear();
Output output;
output.name = "const_" + std::to_string(const_idx_++);
out_.push_back(output);
visited_[constant] = output;
runtime::NDArray array = cn->data;
const auto& shape = array.Shape();
const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
// Get the number of elements.
int64_t num_elems = 1;
for (auto i : shape) num_elems *= i;
const auto* type_node = cn->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
// Define a const buffer: float const_0[64] = {1.0, 2.0, ...};
//
// Technically, you may need: static float* const_0 = (float*)malloc(4 * 64)
// to avoid possible stack overflow.
buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
if (dtype == "float") {
float* p_flt = static_cast<float*>(dl_tensor.data);
for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
if (num_elems) buf_stream << p_flt[num_elems - 1];
} else if (dtype == "int") {
int* p_flt = static_cast<int*>(dl_tensor.data);
for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
if (num_elems) buf_stream << p_flt[num_elems - 1];
} else {
LOG(FATAL) << "Only float and int are supported for now.";
}
buf_stream << "};";
ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
}
void VisitExpr_(const CallNode* call) final {
std::ostringstream macro_stream;
std::ostringstream decl_stream;
......@@ -138,6 +188,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
int func_idx = 0;
/*! \brief The index of allocated buffers. */
int buf_idx_ = 0;
/*! \brief The index of global constants. */
int const_idx_ = 0;
/*! \brief The arguments of a C compiler compatible function. */
Array<Var> ext_func_args_;
/*! \brief The statements of a C compiler compatible function. */
......@@ -148,6 +200,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */
std::vector<Output> out_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
};
class CSourceCodegen : public CSourceModuleCodegenBase {
......
......@@ -197,7 +197,7 @@ class CodegenCBase {
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
bool IsOp(const CallNode* call, std::string op_name) const {
bool IsOp(const CallNode* call, const std::string& op_name) const {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
......@@ -218,7 +218,7 @@ class CodegenCBase {
*
* \return The emitted code string.
*/
std::string JitImpl(std::string ext_func_id, const Array<Var>& args,
std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
const std::vector<std::string>& buf_decl,
const std::vector<std::string>& body,
const std::vector<Output>& out) {
......
......@@ -42,6 +42,8 @@
#include <utility>
#include <vector>
#include "../backend/utils.h"
namespace tvm {
namespace relay {
namespace partitioning {
......@@ -200,15 +202,21 @@ class Partitioner : public ExprMutator {
auto input = VisitExpr(call->args[0]);
Array<Var> params;
Array<Expr> args;
std::unordered_map<std::string, runtime::NDArray> params_bind;
// The subgraph may be merged so we need to update it again.
subgraph = GetSubgraph(GetRef<Call>(call));
CHECK(subgraph);
// Record the constants for propagation.
for (auto pair : subgraph->args) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
args.push_back(pair.second);
}
}
auto subgraph_func =
Function(params, input, call->checked_type_, {});
......@@ -223,6 +231,11 @@ class Partitioner : public ExprMutator {
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
subgraph_func = backend::BindParamsByName(subgraph_func, params_bind);
}
CHECK(!module_->ContainGlobalVar(name))
<< "Global function " << name << " already exists";
// Create a global function and add it to the IRModule for the subgraph.
......
......@@ -634,6 +634,50 @@ def test_function_lifting_inline():
assert relay.analysis.alpha_equal(partitioned, ref_mod)
def test_constant_propagation():
ones = np.ones(shape=(8, 8), dtype="float32")
def expected():
mod = tvm.IRModule()
x = relay.const(ones)
y = relay.var("y", shape=(8, 8))
x0 = relay.const(ones)
y0 = relay.var("y0", shape=(8, 8))
add = x0 + y0
# Function that uses C compiler
func = relay.Function([y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [y])
log = relay.log(add_call)
main = relay.Function([y], log)
mod["main"] = main
return mod
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
add = x + y
log = relay.log(add)
f = relay.Function([x, y], log)
f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)})
mod = tvm.IRModule()
mod["main"] = f
mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
mod = transform.PartitionGraph()(mod)
expected_mod = expected()
assert relay.alpha_equal(mod, expected_mod)
y_data = np.random.rand(8, 8).astype('float32')
np_add = ones + y_data
check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
......@@ -643,3 +687,4 @@ if __name__ == "__main__":
test_extern_dnnl_mobilenet()
test_function_lifting()
test_function_lifting_inline()
test_constant_propagation()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment