Commit 137bf5f4 authored by hlu1, committed by Tianqi Chen

[runtime] reduce set_input and set_input_zero_copy overhead (#3805)

parent ce031438
@@ -31,11 +31,12 @@
 #include <algorithm>
 #include <functional>
+#include <memory>
 #include <numeric>
-#include <vector>
 #include <string>
-#include <memory>
+#include <unordered_set>
 #include <utility>
+#include <vector>

 namespace tvm {
 namespace runtime {
@@ -78,6 +79,11 @@ void GraphRuntime::Init(const std::string& graph_json,
   ctxs_ = ctxs;
   this->SetupStorage();
   this->SetupOpExecs();
+  for (size_t i = 0; i < input_nodes_.size(); i++) {
+    const uint32_t nid = input_nodes_[i];
+    std::string& name = nodes_[nid].name;
+    input_map_[name] = i;
+  }
 }

 /*!
  * \brief Get the input index given the name of input.
@@ -85,11 +91,9 @@ void GraphRuntime::Init(const std::string& graph_json,
  * \return The index of input.
  */
 int GraphRuntime::GetInputIndex(const std::string& name) {
-  for (size_t i = 0; i< input_nodes_.size(); ++i) {
-    uint32_t nid = input_nodes_[i];
-    if (nodes_[nid].name == name) {
-      return static_cast<int>(i);
-    }
+  auto it = input_map_.find(name);
+  if (it != input_map_.end()) {
+    return it->second;
   }
   LOG(WARNING) << "Warning: cannot find \"" << name << "\" among input";
   return -1;
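The hunk above replaces the per-call linear scan over input_nodes_ with a lookup in input_map_, which Init now fills once while loading the graph. A minimal standalone sketch of the same pattern, assuming a hypothetical InputTable class that is not part of the TVM runtime API:

// Sketch only: build a name -> index map once at initialization so every later
// lookup is an O(1) hash probe instead of a linear scan over the input nodes.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

class InputTable {
 public:
  // Called once per input while loading the graph.
  void Add(const std::string& name) {
    indices_[name] = static_cast<int>(names_.size());
    names_.push_back(name);
  }
  // Hash lookup; returns -1 when the name is unknown, mirroring the diff.
  int GetIndex(const std::string& name) const {
    auto it = indices_.find(name);
    return it != indices_.end() ? it->second : -1;
  }

 private:
  std::vector<std::string> names_;
  std::unordered_map<std::string, int> indices_;
};

int main() {
  InputTable inputs;
  inputs.Add("data");
  inputs.Add("weight");
  std::cout << inputs.GetIndex("weight") << "\n";   // 1
  std::cout << inputs.GetIndex("missing") << "\n";  // -1
  return 0;
}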
@@ -125,16 +129,8 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
   }
   // Update the data pointer for each argument of each op
-  for (auto& op_arg : op_args_) {
-    if (op_arg) {
-      const auto it = op_arg->input_entry_ids.find(eid);
-      if (it != op_arg->input_entry_ids.end()) {
-        for (const auto i : it->second) {
-          DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle);
-          t->data = data_ref->data;
-        }
-      }
-    }
+  for (DLTensor* t : input_dltensors_[eid]) {
+    t->data = data_ref->data;
   }
 }

 /*!
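After this change, SetInputZeroCopy only walks the DLTensor* list precomputed for the touched entry id instead of probing every operator's OpArgs. A self-contained sketch of that pointer-swap pattern, using a hypothetical Tensor struct and class names as stand-ins for DLTensor and the TVM runtime types:

// Sketch only: operators keep argument records that point at shared tensor
// headers, and the runtime pre-collects the headers fed by each graph input.
// Setting an input then rewrites the data pointers in that one list, no copy.
#include <iostream>
#include <vector>

struct Tensor {        // stand-in for DLTensor: only the data pointer matters here
  void* data = nullptr;
};

class ZeroCopyInputs {
 public:
  explicit ZeroCopyInputs(size_t num_inputs) : per_input_(num_inputs) {}

  // Done once while wiring up operators: remember which argument headers
  // read from input `eid`.
  void Track(size_t eid, Tensor* arg) { per_input_[eid].push_back(arg); }

  // Hot path: point every tracked argument at the caller-owned buffer.
  void SetInputZeroCopy(size_t eid, void* external_buffer) {
    for (Tensor* t : per_input_[eid]) t->data = external_buffer;
  }

 private:
  std::vector<std::vector<Tensor*>> per_input_;
};

int main() {
  Tensor conv_arg, add_arg;  // argument headers of two ops reading input 0
  ZeroCopyInputs inputs(1);
  inputs.Track(0, &conv_arg);
  inputs.Track(0, &add_arg);

  float user_buffer[4] = {1, 2, 3, 4};  // caller-owned memory, never copied
  inputs.SetInputZeroCopy(0, user_buffer);
  std::cout << (conv_arg.data == user_buffer) << " "
            << (add_arg.data == user_buffer) << "\n";  // 1 1
  return 0;
}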
@@ -324,17 +320,21 @@ void GraphRuntime::SetupStorage() {
 void GraphRuntime::SetupOpExecs() {
   op_execs_.resize(this->GetNumOfNodes());
-  op_args_.resize(this->GetNumOfNodes());
+  input_dltensors_.resize(num_node_entries());
+  std::unordered_set<uint32_t> input_node_eids;
+  for (size_t i = 0; i < input_nodes_.size(); i++) {
+    uint32_t nid = input_nodes_[i];
+    input_node_eids.insert(entry_id(nid, 0));
+  }
   // setup the array and requirements.
   for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
     const auto& inode = nodes_[nid];
     if (inode.op_type == "null") continue;
     std::vector<DLTensor> args;
-    std::vector<uint32_t> input_entry_ids;
     for (const auto& e : inode.inputs) {
       uint32_t eid = this->entry_id(e);
       args.push_back(*(data_entry_[eid].operator->()));
-      input_entry_ids.push_back(eid);
     }
     for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
       uint32_t eid = this->entry_id(nid, index);
@@ -342,16 +342,16 @@ void GraphRuntime::SetupOpExecs() {
     }
     CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
-    std::tie(op_execs_[nid], op_args_[nid]) =
+    std::shared_ptr<OpArgs> op_args = nullptr;
+    std::tie(op_execs_[nid], op_args) =
         CreateTVMOp(inode.param, args, inode.inputs.size());
-    auto& entry_to_input_pos = op_args_[nid]->input_entry_ids;
-    for (uint32_t i = 0; i < input_entry_ids.size(); ++i) {
-      const auto eid = input_entry_ids[i];
-      auto it = entry_to_input_pos.find(eid);
-      if (it == entry_to_input_pos.end()) {
-        entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i});
-      } else {
-        it->second.push_back(i);
-      }
+    for (size_t i = 0; i < inode.inputs.size(); i++) {
+      uint32_t eid = this->entry_id(inode.inputs[i]);
+      // check if op input is model input
+      if (input_node_eids.count(eid) > 0) {
+        input_dltensors_[eid].push_back(
+            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
+      }
     }
   }
...
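The two SetupOpExecs hunks above move all of the entry-id bookkeeping to setup time: the entry ids of the model inputs go into an unordered_set, and only operator arguments whose entry id is in that set get their DLTensor* recorded in input_dltensors_. A hedged, self-contained sketch of that build step, with made-up ArgTensor/OpInput helper types rather than TVM's:

// Sketch only: collect the graph-input entry ids into a hash set, then, while
// walking each operator's inputs, record only the argument tensors whose entry
// id is a graph input. Intermediate entries are skipped entirely.
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

struct ArgTensor { void* data = nullptr; };   // stand-in for a DLTensor argument

struct OpInput {
  uint32_t eid;       // entry id this argument reads from
  ArgTensor* tensor;  // the argument header passed to the op
};

int main() {
  // Entry ids that correspond to graph-level inputs (input_nodes_ in the diff).
  std::vector<uint32_t> input_eids = {0, 3};
  std::unordered_set<uint32_t> input_node_eids(input_eids.begin(), input_eids.end());

  // Flattened view of every operator argument in the graph.
  ArgTensor a, b, c;
  std::vector<OpInput> all_op_inputs = {{0, &a}, {5, &b}, {3, &c}};

  // input_dltensors[eid] lists the argument tensors that must be re-pointed
  // when input `eid` is replaced via zero-copy set_input.
  std::vector<std::vector<ArgTensor*>> input_dltensors(6);
  for (const OpInput& in : all_op_inputs) {
    if (input_node_eids.count(in.eid) > 0) {  // op input is a model input
      input_dltensors[in.eid].push_back(in.tensor);
    }
  }

  std::cout << input_dltensors[0].size() << " "   // 1
            << input_dltensors[3].size() << " "   // 1
            << input_dltensors[5].size() << "\n";  // 0 (intermediate, not tracked)
  return 0;
}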
@@ -70,7 +70,6 @@ struct TVMOpParam {
 class GraphRuntime : public ModuleNode {
   struct OpArgs {
     std::vector<DLTensor> args;
-    std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
     std::vector<TVMValue> arg_values;
     std::vector<int> arg_tcodes;
     std::vector<int64_t> shape_data;
@@ -399,6 +398,10 @@ class GraphRuntime : public ModuleNode {
   std::vector<Node> nodes_;
   /*! \brief The argument nodes. */
   std::vector<uint32_t> input_nodes_;
+  /*! \brief Map of input names to input indices. */
+  std::unordered_map<std::string, uint32_t> input_map_;
+  /*! \brief Used for quick node input DLTensor* lookup given an input eid. */
+  std::vector<std::vector<DLTensor*>> input_dltensors_;
   /*! \brief Used for quick entry indexing. */
   std::vector<uint32_t> node_row_ptr_;
   /*! \brief Output entries. */
@@ -417,8 +420,6 @@ class GraphRuntime : public ModuleNode {
   std::vector<size_t> data_alignment_;
   /*! \brief Operator on each node. */
   std::vector<std::function<void()> > op_execs_;
-  /*! \brief Arg info of TVM ops */
-  std::vector<std::shared_ptr<OpArgs> > op_args_;
 };

 std::vector<TVMContext> GetAllContext(const TVMArgs& args);
...