Unverified Commit 3a75b13d by Tianqi Chen Committed by GitHub

Misc refactor on graph runtime, layout node (#2557)

parent 5b8ff8d0
...@@ -173,7 +173,7 @@ class TVM_DLL DeviceAPI { ...@@ -173,7 +173,7 @@ class TVM_DLL DeviceAPI {
* \param allow_missing Whether allow missing * \param allow_missing Whether allow missing
* \return The corresponding device API. * \return The corresponding device API.
*/ */
TVM_DLL static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false); static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
}; };
/*! \brief The device type bigger than this is RPC device */ /*! \brief The device type bigger than this is RPC device */
......
...@@ -73,9 +73,9 @@ class Layout : public NodeRef { ...@@ -73,9 +73,9 @@ class Layout : public NodeRef {
Layout(const std::string& name) { // NOLINT(*) Layout(const std::string& name) { // NOLINT(*)
node_ = make_node<LayoutNode>(); node_ = make_node<LayoutNode>();
std::vector<uint32_t> superdim_pos(kUniqueDim, -1); std::vector<int32_t> superdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_pos(kUniqueDim, -1); std::vector<int32_t> subdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_size(kUniqueDim, -1); std::vector<int32_t> subdim_size(kUniqueDim, -1);
std::vector<char> layout_simplified; std::vector<char> layout_simplified;
if (name != "__undef__") { // parse layout string if (name != "__undef__") { // parse layout string
......
...@@ -25,11 +25,11 @@ class GraphRuntimeDebug : public GraphRuntime { ...@@ -25,11 +25,11 @@ class GraphRuntimeDebug : public GraphRuntime {
* \return the elapsed time. * \return the elapsed time.
*/ */
double DebugRun(size_t index) { double DebugRun(size_t index) {
CHECK(index < op_execs().size()); CHECK(index < op_execs_.size());
TVMContext ctx = data_entry()[GetEntryId(index, 0)].operator->()->ctx; TVMContext ctx = data_entry_[entry_id(index, 0)]->ctx;
auto tbegin = std::chrono::high_resolution_clock::now(); auto tbegin = std::chrono::high_resolution_clock::now();
if (op_execs()[index]) { if (op_execs_[index]) {
op_execs()[index](); op_execs_[index]();
} }
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto tend = std::chrono::high_resolution_clock::now(); auto tend = std::chrono::high_resolution_clock::now();
...@@ -44,7 +44,7 @@ class GraphRuntimeDebug : public GraphRuntime { ...@@ -44,7 +44,7 @@ class GraphRuntimeDebug : public GraphRuntime {
* \param eid The Entry id of the op. * \param eid The Entry id of the op.
*/ */
NDArray GetOutputByLayer(int index, int eid) { NDArray GetOutputByLayer(int index, int eid) {
return data_entry()[GetEntryId(index, eid)]; return data_entry_[entry_id(index, eid)];
} }
/*! /*!
...@@ -81,15 +81,15 @@ class GraphRuntimeDebug : public GraphRuntime { ...@@ -81,15 +81,15 @@ class GraphRuntimeDebug : public GraphRuntime {
* \param data_out the node data. * \param data_out the node data.
*/ */
void DebugGetNodeOutput(int index, DLTensor* data_out) { void DebugGetNodeOutput(int index, DLTensor* data_out) {
CHECK_LT(static_cast<size_t>(index), op_execs().size()); CHECK_LT(static_cast<size_t>(index), op_execs_.size());
uint32_t eid = index; uint32_t eid = index;
for (size_t i = 0; i < op_execs().size(); ++i) { for (size_t i = 0; i < op_execs_.size(); ++i) {
if (op_execs()[i]) op_execs()[i](); if (op_execs_[i]) op_execs_[i]();
if (static_cast<int>(i) == index) break; if (static_cast<int>(i) == index) break;
} }
data_entry()[eid].CopyTo(data_out); data_entry_[eid].CopyTo(data_out);
} }
}; };
......
...@@ -126,30 +126,6 @@ class GraphRuntime : public ModuleNode { ...@@ -126,30 +126,6 @@ class GraphRuntime : public ModuleNode {
* \param param_blob A binary blob of parameter. * \param param_blob A binary blob of parameter.
*/ */
void LoadParams(const std::string& param_blob); void LoadParams(const std::string& param_blob);
/*!
* \brief Get the tensor vector pointer.
*/
std::vector<NDArray>& data_entry() {
return data_entry_;
}
/*!
* \brief Get the execution function pointer.
*/
std::vector<std::function<void()> >& op_execs() {
return op_execs_;
}
/*!
* \brief Get node entry index.
* \param nid Node id.
* \param index Index of the nodes.
*/
uint32_t GetEntryId(uint32_t nid, uint32_t index) const {
return node_row_ptr_[nid] + index;
}
/*! /*!
* \brief Get total number of nodes. * \brief Get total number of nodes.
* \return Total number of nodes. * \return Total number of nodes.
...@@ -163,7 +139,7 @@ class GraphRuntime : public ModuleNode { ...@@ -163,7 +139,7 @@ class GraphRuntime : public ModuleNode {
} }
private: protected:
// Memory pool entry. // Memory pool entry.
struct PoolEntry { struct PoolEntry {
size_t size; size_t size;
......
...@@ -18,11 +18,14 @@ def test_ewise(): ...@@ -18,11 +18,14 @@ def test_ewise():
shape = (20, 3) shape = (20, 3)
def test_apply(func, name, f_numpy, low, high): def test_apply(func, name, f_numpy, low, high, check_round=False):
B = func(A) B = func(A)
assert tuple(B.shape) == tuple(A.shape) assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name assert B.op.body[0].name == name
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
# avoid round check too close to boundary
if check_round:
a_np += ((np.fmod(a_np, 1) - 0.5) < 1e-6) * 1e-5
b_np = f_numpy(a_np) b_np = f_numpy(a_np)
def check_device(device): def check_device(device):
...@@ -48,7 +51,7 @@ def test_ewise(): ...@@ -48,7 +51,7 @@ def test_ewise():
test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100) test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100) test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100) test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
test_apply(topi.exp, "exp", np.exp, -1, 1) test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10) test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1) test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
......
...@@ -638,31 +638,23 @@ class InsnQueue : public BaseQueue { ...@@ -638,31 +638,23 @@ class InsnQueue : public BaseQueue {
static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep)); static_cast<int>(c.mem.push_next_dep));
// Count status in queues // Count status in queues
if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) { if (c.mem.opcode == VTA_OPCODE_STORE) {
if (c.mem.opcode == VTA_OPCODE_STORE) { CHECK(c.mem.pop_next_dep == false);
CHECK(c.mem.pop_next_dep == false); CHECK(c.mem.push_next_dep == false);
CHECK(c.mem.push_next_dep == false); if (c.mem.pop_prev_dep) g2s_queue--;
if (c.mem.pop_prev_dep) g2s_queue--; if (c.mem.push_prev_dep) s2g_queue++;
if (c.mem.push_prev_dep) s2g_queue++; } else if (c.mem.opcode == VTA_OPCODE_LOAD &&
} else if (c.mem.opcode == VTA_OPCODE_LOAD && (c.mem.memory_type == VTA_MEM_ID_INP ||
(c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT) ) {
c.mem.memory_type == VTA_MEM_ID_WGT) ) { CHECK(c.mem.pop_prev_dep == false);
CHECK(c.mem.pop_prev_dep == false); CHECK(c.mem.push_prev_dep == false);
CHECK(c.mem.push_prev_dep == false); if (c.mem.pop_next_dep) g2l_queue--;
if (c.mem.pop_next_dep) g2l_queue--; if (c.mem.push_next_dep) l2g_queue++;
if (c.mem.push_next_dep) l2g_queue++; } else {
} else { if (c.mem.pop_prev_dep) l2g_queue--;
if (c.mem.pop_prev_dep) l2g_queue--; if (c.mem.push_prev_dep) g2l_queue++;
if (c.mem.push_prev_dep) g2l_queue++; if (c.mem.pop_next_dep) s2g_queue--;
if (c.mem.pop_next_dep) s2g_queue--; if (c.mem.push_next_dep) g2s_queue++;
if (c.mem.push_next_dep) g2s_queue++;
}
} else if (c.mem.opcode == VTA_OPCODE_GEMM) {
// Print instruction field information
if (c.gemm.pop_prev_dep) l2g_queue--;
if (c.gemm.push_prev_dep) g2l_queue++;
if (c.gemm.pop_next_dep) s2g_queue--;
if (c.gemm.push_next_dep) g2s_queue++;
} }
printf("\tl2g_queue = %d, g2l_queue = %d\n", l2g_queue, g2l_queue); printf("\tl2g_queue = %d, g2l_queue = %d\n", l2g_queue, g2l_queue);
printf("\ts2g_queue = %d, g2s_queue = %d\n", s2g_queue, g2s_queue); printf("\ts2g_queue = %d, g2s_queue = %d\n", s2g_queue, g2s_queue);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment