Commit 2a898181 by hlu1 Committed by Tianqi Chen

[Graph Runtime] Run_individual for benchmarking individual layers (#2569)

parent 1fa30a84
...@@ -89,6 +89,7 @@ class GraphModuleDebug(graph_runtime.GraphModule): ...@@ -89,6 +89,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
self._dump_path = None self._dump_path = None
self._debug_run = module["debug_run"] self._debug_run = module["debug_run"]
self._get_output_by_layer = module["get_output_by_layer"] self._get_output_by_layer = module["get_output_by_layer"]
self._run_individual = module["run_individual"]
graph_runtime.GraphModule.__init__(self, module) graph_runtime.GraphModule.__init__(self, module)
self._create_debug_env(graph_json_str, ctx) self._create_debug_env(graph_json_str, ctx)
...@@ -222,6 +223,9 @@ class GraphModuleDebug(graph_runtime.GraphModule): ...@@ -222,6 +223,9 @@ class GraphModuleDebug(graph_runtime.GraphModule):
# Step 3. Display the collected information # Step 3. Display the collected information
self.debug_datum.display_debug_result() self.debug_datum.display_debug_result()
def run_individual(self, number, repeat=1, min_repeat_ms=0):
self._run_individual(number, repeat, min_repeat_ms)
def exit(self): def exit(self):
"""Exits the dump folder and all its contents""" """Exits the dump folder and all its contents"""
self._remove_dump_root() self._remove_dump_root()
...@@ -39,6 +39,65 @@ class GraphRuntimeDebug : public GraphRuntime { ...@@ -39,6 +39,65 @@ class GraphRuntimeDebug : public GraphRuntime {
} }
/*! /*!
* \brief Run each operation in the graph and print out the runtime per op.
* \param number The number of times to run this function for taking average.
* \param repeat The number of times to repeat the measurement.
In total, the function will be invoked (1 + number x repeat) times,
where the first one is warmed up and will be discarded in case
there is lazy initialization.
* \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
By default, one `repeat` contains `number` runs. If this parameter is set,
the parameters `number` will be dynamically adjusted to meet the
minimum duration requirement of one `repeat`.
*/
void RunIndividual(int number, int repeat, int min_repeat_ms) {
// warmup run
GraphRuntime::Run();
std::vector<double> time_per_op(op_execs_.size(), 0);
for (int i = 0; i < repeat; ++i) {
std::chrono::time_point<
std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
double duration_ms = 0.0;
do {
std::fill(time_per_op.begin(), time_per_op.end(), 0);
if (duration_ms > 0.0) {
number = static_cast<int>(
std::max((min_repeat_ms / (duration_ms / number) + 1),
number * 1.618)); // 1.618 is chosen by random
}
tbegin = std::chrono::high_resolution_clock::now();
for (int k = 0; k < number; k++) {
for (size_t index = 0; index < op_execs_.size(); ++index) {
if (op_execs_[index]) {
const TVMContext& ctx = data_entry_[entry_id(index, 0)]->ctx;
auto op_tbegin = std::chrono::high_resolution_clock::now();
op_execs_[index]();
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_tend = std::chrono::high_resolution_clock::now();
double op_duration = std::chrono::duration_cast<
std::chrono::duration<double> >(op_tend - op_tbegin).count();
time_per_op[index] += op_duration * 1000; // ms
}
}
}
tend = std::chrono::high_resolution_clock::now();
duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
(tend - tbegin).count() * 1000;
} while (duration_ms < min_repeat_ms);
LOG(INFO) << "Repeat: " << i;
int op = 0;
for (size_t index = 0; index < time_per_op.size(); index++) {
if (op_execs_[index]) {
time_per_op[index] /= number;
LOG(INFO) << "Op #" << op++ << ": " << time_per_op[index] << " ms/iter";
}
}
}
}
/*!
* \brief Run each operation and get the output. * \brief Run each operation and get the output.
* \param index The index of op which needs to be returned. * \param index The index of op which needs to be returned.
* \param eid The Entry id of the op. * \param eid The Entry id of the op.
...@@ -119,6 +178,16 @@ PackedFunc GraphRuntimeDebug::GetFunction( ...@@ -119,6 +178,16 @@ PackedFunc GraphRuntimeDebug::GetFunction(
this->DebugGetNodeOutput(args[0], args[1]); this->DebugGetNodeOutput(args[0], args[1]);
} }
}); });
} else if (name == "run_individual") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int number = args[0];
int repeat = args[1];
int min_repeat_ms = args[2];
CHECK_GT(number, 0);
CHECK_GT(repeat, 0);
CHECK_GE(min_repeat_ms, 0);
this->RunIndividual(number, repeat, min_repeat_ms);
});
} else { } else {
return GraphRuntime::GetFunction(name, sptr_to_self); return GraphRuntime::GetFunction(name, sptr_to_self);
} }
......
...@@ -68,6 +68,9 @@ def test_graph_simple(): ...@@ -68,6 +68,9 @@ def test_graph_simple():
out = mod.get_output(0, tvm.nd.empty((n,))) out = mod.get_output(0, tvm.nd.empty((n,)))
np.testing.assert_equal(out.asnumpy(), a + 1) np.testing.assert_equal(out.asnumpy(), a + 1)
#test individual run
mod.run_individual(20, 2, 1)
mod.exit() mod.exit()
#verify dump root delete after cleanup #verify dump root delete after cleanup
assert(not os.path.exists(directory)) assert(not os.path.exists(directory))
...@@ -94,6 +97,7 @@ def test_graph_simple(): ...@@ -94,6 +97,7 @@ def test_graph_simple():
mod.run(x=tvm.nd.array(a, ctx)) mod.run(x=tvm.nd.array(a, ctx))
out = tvm.nd.empty((n,), ctx=ctx) out = tvm.nd.empty((n,), ctx=ctx)
out = mod.get_output(0, out) out = mod.get_output(0, out)
mod.run_individual(20, 2, 1)
np.testing.assert_equal(out.asnumpy(), a + 1) np.testing.assert_equal(out.asnumpy(), a + 1)
check_verify() check_verify()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment