Unverified commit 3a75b13d by Tianqi Chen, committed via GitHub

Misc refactor on graph runtime, layout node (#2557)

parent 5b8ff8d0
......@@ -173,7 +173,7 @@ class TVM_DLL DeviceAPI {
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
TVM_DLL static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
};
/*! \brief The device type bigger than this is RPC device */
......
......@@ -73,9 +73,9 @@ class Layout : public NodeRef {
Layout(const std::string& name) { // NOLINT(*)
node_ = make_node<LayoutNode>();
std::vector<uint32_t> superdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_size(kUniqueDim, -1);
std::vector<int32_t> superdim_pos(kUniqueDim, -1);
std::vector<int32_t> subdim_pos(kUniqueDim, -1);
std::vector<int32_t> subdim_size(kUniqueDim, -1);
std::vector<char> layout_simplified;
if (name != "__undef__") { // parse layout string
......
......@@ -25,11 +25,11 @@ class GraphRuntimeDebug : public GraphRuntime {
* \return the elapsed time.
*/
double DebugRun(size_t index) {
CHECK(index < op_execs().size());
TVMContext ctx = data_entry()[GetEntryId(index, 0)].operator->()->ctx;
CHECK(index < op_execs_.size());
TVMContext ctx = data_entry_[entry_id(index, 0)]->ctx;
auto tbegin = std::chrono::high_resolution_clock::now();
if (op_execs()[index]) {
op_execs()[index]();
if (op_execs_[index]) {
op_execs_[index]();
}
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto tend = std::chrono::high_resolution_clock::now();
......@@ -44,7 +44,7 @@ class GraphRuntimeDebug : public GraphRuntime {
* \param eid The Entry id of the op.
*/
NDArray GetOutputByLayer(int index, int eid) {
// NOTE(review): the two return statements below are the old/new sides of a
// rendered diff (pre-refactor accessors vs. direct protected members);
// only one belongs in the real source — the second is unreachable here.
return data_entry()[GetEntryId(index, eid)];
return data_entry_[entry_id(index, eid)];
}
/*!
......@@ -81,15 +81,15 @@ class GraphRuntimeDebug : public GraphRuntime {
* \param data_out the node data.
*/
void DebugGetNodeOutput(int index, DLTensor* data_out) {
// NOTE(review): this hunk interleaves pre-refactor accessor calls
// (op_execs(), data_entry()) with their post-refactor replacements
// (op_execs_, data_entry_); each adjacent near-identical pair of lines
// is an old/new diff pair, not two real statements.
// Bounds-check the requested op index before executing anything.
CHECK_LT(static_cast<size_t>(index), op_execs().size());
CHECK_LT(static_cast<size_t>(index), op_execs_.size());
uint32_t eid = index;
// Run every op up to and including `index` so its output is materialized;
// empty entries are skipped.
for (size_t i = 0; i < op_execs().size(); ++i) {
if (op_execs()[i]) op_execs()[i]();
for (size_t i = 0; i < op_execs_.size(); ++i) {
if (op_execs_[i]) op_execs_[i]();
if (static_cast<int>(i) == index) break;
}
// Copy the produced entry out into the caller-supplied tensor.
data_entry()[eid].CopyTo(data_out);
data_entry_[eid].CopyTo(data_out);
}
};
......
......@@ -126,30 +126,6 @@ class GraphRuntime : public ModuleNode {
* \param param_blob A binary blob of parameter.
*/
void LoadParams(const std::string& param_blob);
/*!
* \brief Accessor for the runtime's data entry storage.
* \return Mutable reference to the vector of NDArray data entries
*         (backed by the data_entry_ member); used by the debug
*         runtime to inspect per-entry tensors.
*/
std::vector<NDArray>& data_entry() {
return data_entry_;
}
/*!
* \brief Accessor for the per-node execution closures.
* \return Mutable reference to the vector of std::function thunks
*         (op_execs_); an entry may be empty, so callers check it
*         before invoking — presumably for no-op nodes (TODO confirm).
*/
std::vector<std::function<void()> >& op_execs() {
return op_execs_;
}
/*!
* \brief Compute the flat entry id for entry \p index of node \p nid.
* \param nid Node id.
* \param index Index within the node's entries.
* \return The node's base offset in node_row_ptr_ plus the entry index.
*/
uint32_t GetEntryId(uint32_t nid, uint32_t index) const {
const uint32_t row_base = node_row_ptr_[nid];
return row_base + index;
}
/*!
* \brief Get total number of nodes.
* \return Total number of nodes.
......@@ -163,7 +139,7 @@ class GraphRuntime : public ModuleNode {
}
private:
protected:
// Memory pool entry.
struct PoolEntry {
size_t size;
......
......@@ -18,11 +18,14 @@ def test_ewise():
shape = (20, 3)
def test_apply(func, name, f_numpy, low, high):
def test_apply(func, name, f_numpy, low, high, check_round=False):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
# avoid round check too close to boundary
if check_round:
a_np += ((np.fmod(a_np, 1) - 0.5) < 1e-6) * 1e-5
b_np = f_numpy(a_np)
def check_device(device):
......@@ -48,7 +51,7 @@ def test_ewise():
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
......
......@@ -638,31 +638,23 @@ class InsnQueue : public BaseQueue {
static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep));
// Count status in queues
if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) {
if (c.mem.opcode == VTA_OPCODE_STORE) {
CHECK(c.mem.pop_next_dep == false);
CHECK(c.mem.push_next_dep == false);
if (c.mem.pop_prev_dep) g2s_queue--;
if (c.mem.push_prev_dep) s2g_queue++;
} else if (c.mem.opcode == VTA_OPCODE_LOAD &&
(c.mem.memory_type == VTA_MEM_ID_INP ||
c.mem.memory_type == VTA_MEM_ID_WGT) ) {
CHECK(c.mem.pop_prev_dep == false);
CHECK(c.mem.push_prev_dep == false);
if (c.mem.pop_next_dep) g2l_queue--;
if (c.mem.push_next_dep) l2g_queue++;
} else {
if (c.mem.pop_prev_dep) l2g_queue--;
if (c.mem.push_prev_dep) g2l_queue++;
if (c.mem.pop_next_dep) s2g_queue--;
if (c.mem.push_next_dep) g2s_queue++;
}
} else if (c.mem.opcode == VTA_OPCODE_GEMM) {
// Print instruction field information
if (c.gemm.pop_prev_dep) l2g_queue--;
if (c.gemm.push_prev_dep) g2l_queue++;
if (c.gemm.pop_next_dep) s2g_queue--;
if (c.gemm.push_next_dep) g2s_queue++;
if (c.mem.opcode == VTA_OPCODE_STORE) {
CHECK(c.mem.pop_next_dep == false);
CHECK(c.mem.push_next_dep == false);
if (c.mem.pop_prev_dep) g2s_queue--;
if (c.mem.push_prev_dep) s2g_queue++;
} else if (c.mem.opcode == VTA_OPCODE_LOAD &&
(c.mem.memory_type == VTA_MEM_ID_INP ||
c.mem.memory_type == VTA_MEM_ID_WGT) ) {
CHECK(c.mem.pop_prev_dep == false);
CHECK(c.mem.push_prev_dep == false);
if (c.mem.pop_next_dep) g2l_queue--;
if (c.mem.push_next_dep) l2g_queue++;
} else {
if (c.mem.pop_prev_dep) l2g_queue--;
if (c.mem.push_prev_dep) g2l_queue++;
if (c.mem.pop_next_dep) s2g_queue--;
if (c.mem.push_next_dep) g2s_queue++;
}
printf("\tl2g_queue = %d, g2l_queue = %d\n", l2g_queue, g2l_queue);
printf("\ts2g_queue = %d, g2s_queue = %d\n", s2g_queue, g2s_queue);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment