Commit afd4b3e4 by Yinghai Lu, committed by Tianqi Chen

[Runtime] Enable set_input_zero_copy in GraphRuntime (#3416)

* Enable set_input_zero_copy in GraphRuntime

* Fix LoadParams

* Fix

* lint

* Fix remote context issue

* Fix

* Remove LOG

* Remove unused variables

* Add tests

* works

* More test scenarios

* make it simpler

* Remove unnecessary changes

* Address comments

* More comments

* Address comments

* Fix build
parent 9fad94cc
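
The change in a nutshell: set_input copies the caller's tensor into the runtime's storage pool via CopyFrom, while the new set_input_zero_copy validates an externally owned DLTensor (alignment, shape, device) and rewires every cached op argument to read from it directly. A minimal C++ usage sketch follows; the prebuilt module run_mod and the input name "x" are illustrative assumptions, not part of this commit:

#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

// Sketch: run a graph runtime module on caller-owned input memory.
void RunZeroCopy(tvm::runtime::Module run_mod) {
  using tvm::runtime::NDArray;
  // NDArray::Empty allocates at kAllocAlignment, which SetInputZeroCopy checks.
  NDArray x = NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
  auto set_input_zero_copy = run_mod.GetFunction("set_input_zero_copy", false);
  auto run = run_mod.GetFunction("run", false);
  // Rewires the cached op arguments to x's buffer; no data is copied.
  set_input_zero_copy("x", &x.ToDLPack()->dl_tensor);
  run();
  // The caller must keep x alive, with a stable data pointer, across runs.
}
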
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,6 +23,7 @@
*/
#include "graph_runtime.h"
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
@@ -38,6 +39,13 @@
namespace tvm {
namespace runtime {
namespace details {
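// Alignment required of an entry's data: the element size in bytes
// (bits / 8 * lanes), raised to at least kAllocAlignment.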
inline size_t GetDataAlignment(const DLTensor& arr) {
size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
if (align < kAllocAlignment) return kAllocAlignment;
return align;
}
} // namespace details
/*!
* \brief Run all the operations one by one.
@@ -97,6 +105,39 @@ void GraphRuntime::SetInput(int index, DLTensor* data_in) {
data_entry_[eid].CopyFrom(data_in);
}
/*!
* \brief set index-th input to the graph without copying the data.
* \param index The input index.
* \param data_ref The input data that is referred.
*/
void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
CHECK_LT(static_cast<size_t>(index), input_nodes_.size());
uint32_t eid = this->entry_id(input_nodes_[index], 0);
const DLTensor* old_t = data_entry_[eid].operator->();
// check the consistency of input
CHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref));
CHECK_EQ(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment, 0);
CHECK_EQ(old_t->ndim, static_cast<size_t>(data_ref->ndim));
CHECK_EQ(old_t->ctx.device_type, data_ref->ctx.device_type);
CHECK_EQ(old_t->ctx.device_id, data_ref->ctx.device_id);
for (auto i = 0; i < data_ref->ndim; ++i) {
CHECK_EQ(old_t->shape[i], data_ref->shape[i]);
}
// Update the data pointer for each argument of each op
for (auto& op_arg : op_args_) {
if (op_arg) {
const auto it = op_arg->input_entry_ids.find(eid);
if (it != op_arg->input_entry_ids.end()) {
for (const auto i : it->second) {
DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle);
t->data = data_ref->data;
}
}
}
}
}
/*!
* \brief Get the number of outputs
*
* \return The number of outputs from graph.
@@ -184,7 +225,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
}
}
void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
@@ -206,6 +247,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK_EQ(data_entry_[eid].use_count(), 1);
data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
CHECK_GT(data_entry_[eid].use_count(), 1);
const DLTensor* tmp = data_entry_[eid].operator->();
data_alignment_[eid] = details::GetDataAlignment(*tmp);
}
this->SetupOpExecs();
}
@@ -268,23 +311,30 @@ void GraphRuntime::SetupStorage() {
// memory assignment for each node entry. The allocated memory on each device
// is mapped to this pool.
data_entry_.resize(num_node_entries());
data_alignment_.resize(num_node_entries());
for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = attrs_.storage_id[i];
CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
data_entry_[i] =
storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
const DLTensor* tmp = data_entry_[i].operator->();
data_alignment_[i] = details::GetDataAlignment(*tmp);
}
}
void GraphRuntime::SetupOpExecs() {
op_execs_.resize(this->GetNumOfNodes());
op_args_.resize(this->GetNumOfNodes());
// setup the array and requirements.
for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
const auto& inode = nodes_[nid];
if (inode.op_type == "null") continue;
std::vector<DLTensor> args;
std::vector<uint32_t> input_entry_ids;
for (const auto& e : inode.inputs) {
- args.push_back(*(data_entry_[this->entry_id(e)].operator->()));
+ uint32_t eid = this->entry_id(e);
+ args.push_back(*(data_entry_[eid].operator->()));
+ input_entry_ids.push_back(eid);
}
for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
uint32_t eid = this->entry_id(nid, index);
@@ -292,29 +342,34 @@ void GraphRuntime::SetupOpExecs() {
}
CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
- op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size());
+ std::tie(op_execs_[nid], op_args_[nid]) =
+ CreateTVMOp(inode.param, args, inode.inputs.size());
auto& entry_to_input_pos = op_args_[nid]->input_entry_ids;
for (uint32_t i = 0; i < input_entry_ids.size(); ++i) {
const auto eid = input_entry_ids[i];
auto it = entry_to_input_pos.find(eid);
if (it == entry_to_input_pos.end()) {
entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i});
} else {
it->second.push_back(i);
}
}
}
}
- std::function<void()> GraphRuntime::CreateTVMOp(
+ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRuntime::CreateTVMOp(
const TVMOpParam& param,
const std::vector<DLTensor>& args,
size_t num_inputs) {
- struct OpArgs {
- std::vector<DLTensor> args;
- std::vector<TVMValue> arg_values;
- std::vector<int> arg_tcodes;
- std::vector<int64_t> shape_data;
- };
- std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>();
+ std::shared_ptr<GraphRuntime::OpArgs> arg_ptr = std::make_shared<GraphRuntime::OpArgs>();
// setup address.
- arg_ptr->args = std::move(args);
+ arg_ptr->args = args;
if (param.flatten_data) {
arg_ptr->shape_data.resize(arg_ptr->args.size());
}
for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
TVMValue v;
- DLTensor* t = &(arg_ptr->args[i]);
+ DLTensor* t = &arg_ptr->args[i];
v.v_handle = t;
arg_ptr->arg_values.push_back(v);
arg_ptr->arg_tcodes.push_back(kArrayHandle);
@@ -327,7 +382,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
}
if (param.func_name == "__nop") {
- return [](){};
+ return {[](){}, arg_ptr};
} else if (param.func_name == "__copy") {
// Perform cross device data copy.
// Directly copy data from the input to the output.
@@ -336,7 +391,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
};
- return fexec;
+ return {fexec, arg_ptr};
}
// Get compiled function from the module that contains both host and device
@@ -351,7 +406,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
static_cast<int>(arg_ptr->arg_values.size()));
pf.CallPacked(targs, &rv);
};
- return fexec;
+ return {fexec, arg_ptr};
}
PackedFunc GraphRuntime::GetFunction(
@@ -367,14 +422,23 @@ PackedFunc GraphRuntime::GetFunction(
this->SetInput(args[0], args[1]);
}
});
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
int in_idx = this->GetInputIndex(args[0]);
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else {
this->SetInputZeroCopy(args[0], args[1]);
}
});
} else if (name == "get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args.num_args == 2) {
this->CopyOutputTo(args[0], args[1]);
} else {
*rv = this->GetOutput(args[0]);
}
});
} else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0;
......
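
For orientation, here is a self-contained sketch of the bookkeeping idea behind SetInputZeroCopy; the names below (OpArgsSketch, RewireEntry) are hypothetical stand-ins, not runtime API. Each op caches its DLTensor arguments once in SetupOpExecs, and a map from node-entry id to argument positions lets a later call patch just the data pointers, so nothing is re-planned or re-allocated:

#include <cstdint>
#include <unordered_map>
#include <vector>
#include <dlpack/dlpack.h>

// Hypothetical stand-in for the per-op argument record (mirrors the OpArgs
// struct this commit adds to graph_runtime.h).
struct OpArgsSketch {
  std::vector<DLTensor> args;  // argument tensors, cached by value at setup
  // node-entry id -> positions in args that read that entry (a vector,
  // because one op may consume the same entry more than once)
  std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
};

// Point every cached argument that reads entry eid at new_data.
inline void RewireEntry(OpArgsSketch* op_arg, uint32_t eid, void* new_data) {
  auto it = op_arg->input_entry_ids.find(eid);
  if (it == op_arg->input_entry_ids.end()) return;
  for (uint32_t pos : it->second) {
    op_arg->args[pos].data = new_data;
  }
}
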
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -34,6 +34,7 @@
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include <string>
@@ -67,6 +68,14 @@
* TVM runtime PackedFunc API.
*/
class GraphRuntime : public ModuleNode {
struct OpArgs {
std::vector<DLTensor> args;
std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
std::vector<TVMValue> arg_values;
std::vector<int> arg_tcodes;
std::vector<int64_t> shape_data;
};
public:
/*!
* \brief Get member function to front-end
@@ -112,6 +121,12 @@ class GraphRuntime : public ModuleNode {
*/
void SetInput(int index, DLTensor* data_in);
/*!
* \brief set index-th input to the graph without copying the data
* \param index The input index.
* \param data_ref The input data that is referred.
*/
void SetInputZeroCopy(int index, DLTensor* data_ref);
/*!
* \brief Get the number of outputs
*
* \return The number of outputs from graph.
@@ -365,9 +380,9 @@ class GraphRuntime : public ModuleNode {
* \param num_inputs Number of inputs.
* \return The created executor.
*/
- std::function<void()> CreateTVMOp(const TVMOpParam& attrs,
- const std::vector<DLTensor>& args,
- size_t num_inputs);
+ std::pair<std::function<void()>, std::shared_ptr<OpArgs> > CreateTVMOp(
+ const TVMOpParam& attrs, const std::vector<DLTensor>& args,
+ size_t num_inputs);
// Get node entry index.
uint32_t entry_id(uint32_t nid, uint32_t index) const {
return node_row_ptr_[nid] + index;
@@ -398,8 +413,12 @@ class GraphRuntime : public ModuleNode {
std::vector<NDArray> storage_pool_;
/*! \brief Data entry of each node. */
std::vector<NDArray> data_entry_;
/*! \brief Data alignment of each node. */
std::vector<size_t> data_alignment_;
/*! \brief Operator on each node. */
std::vector<std::function<void()> > op_execs_;
/*! \brief Arg info of TVM ops */
std::vector<std::shared_ptr<OpArgs> > op_args_;
};
std::vector<TVMContext> GetAllContext(const TVMArgs& args);
......
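
A hedged sketch of the validation rule that the cached data_alignment_ enables; kAllocAlignmentSketch and QualifiesForZeroCopy are illustrative names, and the value 64 is an assumption about kAllocAlignment rather than something this diff states:

#include <cstddef>
#include <cstdint>
#include <dlpack/dlpack.h>

constexpr size_t kAllocAlignmentSketch = 64;  // assumed value of kAllocAlignment

// The rule GetDataAlignment encodes: an entry's required alignment is the
// element size in bytes times the lane count, but never below the allocator's
// minimum alignment.
inline size_t DataAlignment(const DLTensor& arr) {
  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
  return align < kAllocAlignmentSketch ? kAllocAlignmentSketch : align;
}

// An external tensor qualifies for zero copy only if its alignment matches the
// value cached for that entry and its data pointer is itself aligned.
inline bool QualifiesForZeroCopy(const DLTensor& candidate, size_t cached_alignment) {
  return DataAlignment(candidate) == cached_alignment &&
         reinterpret_cast<uintptr_t>(candidate.data) % kAllocAlignmentSketch == 0;
}
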
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
@@ -85,18 +85,41 @@ TEST(Relay, BuildModule) {
auto ctx = A->ctx;
auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id);
auto set_input_f = run_mod.GetFunction("set_input", false);
auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false);
auto run_f = run_mod.GetFunction("run", false);
auto get_output_f = run_mod.GetFunction("get_output", false);
set_input_f("a", A);
set_input_f("b", B);
set_input_f("c", C);
set_input_f("a", &A.ToDLPack()->dl_tensor);
set_input_f("b", &B.ToDLPack()->dl_tensor);
set_input_f("c", &C.ToDLPack()->dl_tensor);
run_f();
tvm::runtime::NDArray Y = get_output_f(0);
auto pY = (float*)Y.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
}
// mutate the input a bit and run it again
for (int i = 0; i < 6; ++i) {
pB[i] = i + 3;
}
run_f();
tvm::runtime::NDArray Y2 = get_output_f(0);
auto pY2 = (float*)Y2.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4);
}
// attach a different input and run it again
auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pC2 = (float*)C2.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
pC2[i] = i + 4;
}
set_input_f("c", &C2.ToDLPack()->dl_tensor);
run_f();
tvm::runtime::NDArray Y3 = get_output_f(0);
auto pY3 = (float*)Y3.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4);
}
}
int main(int argc, char ** argv) {
......