Commit 6b11ffb9 by Zhi Committed by Yao Wang

Sort VM stats by time (#4601)

parent a8c36921
......@@ -38,8 +38,21 @@ class VirtualMachineProfiler(vm.VirtualMachine):
self._set_input = self.mod["set_input"]
self._reset = self.mod["reset"]
def get_stat(self):
return self._get_stat()
def get_stat(self, sort_by_time=True):
"""Get the statistics of executed ops.
Parameters
----------
sort_by_time: Optional[Boolean]
Set to indicate the returned results are sorted by execution time in
the descending order. It is printed in the random order if this
field is not set.
Returns
-------
The execution statistics in string.
"""
return self._get_stat(sort_by_time)
def reset(self):
self._reset()
......@@ -31,6 +31,7 @@
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "vm.h"
......@@ -43,16 +44,32 @@ PackedFunc VirtualMachineDebug::GetFunction(
const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_stat") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1U);
std::vector<std::pair<Index, double>> op_acc_time;
for (auto kv : op_durations_) {
auto val = std::make_pair(
kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0));
op_acc_time.push_back(val);
}
bool sort_by_time = args[0];
if (sort_by_time) {
auto comp = [](const std::pair<Index, double>& lhs,
const std::pair<Index, double>& rhs) {
return lhs.second > rhs.second;
};
std::sort(op_acc_time.begin(), op_acc_time.end(), comp);
}
double total_duration = 0.0;
int64_t total_packed_funcs = 0;
std::ostringstream os;
os << std::setw(30) << std::left << "#OpName"
<< "\t" << std::setw(10) << std::left << "#InvokeCount"
<< "\t"
<< "#Duration(us): Sum/Mean/Min/Max" << std::endl;
for (auto kv : op_durations_) {
for (auto kv : op_acc_time) {
auto vals = op_durations_[kv.first];
auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);;
auto sum = kv.second;
auto mean = sum / static_cast<double>(vals.size());
auto min_value = *std::min_element(vals.begin(), vals.end());
auto max_value = *std::max_element(vals.begin(), vals.end());
......@@ -62,8 +79,10 @@ PackedFunc VirtualMachineDebug::GetFunction(
<< sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;
total_duration += sum;
total_packed_funcs += op_invokes_[kv.first];
}
os << "Total Duration " << total_duration << " us" << std::endl;
os << "\nTotal Duration: " << total_duration << " us.\t"
<< "Total Packed Functions: " << total_packed_funcs << std::endl;
*rv = os.str();
});
} else if (name == "reset") {
......
......@@ -35,6 +35,7 @@ def test_basic():
data = np.random.rand(1, 3, 224, 224).astype('float32')
res = vm.invoke("main", [data])
print("\n{}".format(vm.get_stat()))
print("\n{}".format(vm.get_stat(False)))
if __name__ == "__main__":
test_basic()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment