Commit ba8d00c2 by Tianqi Chen Committed by GitHub

[TIMER] Enhance time evaluator to create multiple results (#830)

parent a7cd0a89
"""Container of compiled functions of TVM.""" """Container of compiled functions of TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import struct
from collections import namedtuple from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .contrib import cc as _cc, tar as _tar, util as _util from .contrib import cc as _cc, tar as _tar, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"]) ProfileResult = namedtuple("ProfileResult", ["mean", "results"])
class Module(ModuleBase): class Module(ModuleBase):
...@@ -110,7 +111,7 @@ class Module(ModuleBase): ...@@ -110,7 +111,7 @@ class Module(ModuleBase):
fcompile = _cc.create_shared fcompile = _cc.create_shared
fcompile(file_name, files, **kwargs) fcompile(file_name, files, **kwargs)
def time_evaluator(self, func_name, ctx, number): def time_evaluator(self, func_name, ctx, number, repeat=1):
"""Get an evaluator that measures time cost of running function. """Get an evaluator that measures time cost of running function.
Parameters Parameters
...@@ -122,11 +123,15 @@ class Module(ModuleBase): ...@@ -122,11 +123,15 @@ class Module(ModuleBase):
The context we should run this function on. The context we should run this function on.
number: int number: int
The number of repeative times to run evaluation. The number of steps used in measuring each time interval
repeat: int, optional
Number of times to run the timer measurement
If repeat equals 3, then we will get 3 numbers in the ProfileResult.
Note Note
---- ----
The function will be invoked number + 1 times, The function will be invoked repeat * number + 1 times,
with the first call discarded in case there is lazy initialization. with the first call discarded in case there is lazy initialization.
Returns Returns
...@@ -137,13 +142,16 @@ class Module(ModuleBase): ...@@ -137,13 +142,16 @@ class Module(ModuleBase):
""" """
try: try:
feval = _RPCTimeEvaluator( feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number) self, func_name, ctx.device_type, ctx.device_id, number, repeat)
def evaluator(*args): def evaluator(*args):
"""Internal wrapped evaluator.""" """Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future. # Wrap feval so we can add more stats in future.
mean = feval(*args) blob = feval(*args)
return ProfileResult(mean=mean) fmt = "@" + ("d" * repeat)
results = struct.unpack(fmt, blob)
mean = sum(results) / float(repeat)
return ProfileResult(mean=mean, results=results)
return evaluator return evaluator
except NameError: except NameError:
......
...@@ -77,10 +77,11 @@ class RPCModuleNode final : public ModuleNode { ...@@ -77,10 +77,11 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc GetTimeEvaluator(const std::string& name, PackedFunc GetTimeEvaluator(const std::string& name,
TVMContext ctx, TVMContext ctx,
int nstep) { int number,
int repeat) {
RPCFuncHandle handle = GetFuncHandle(name); RPCFuncHandle handle = GetFuncHandle(name);
if (handle == nullptr) return PackedFunc(); if (handle == nullptr) return PackedFunc();
handle = sess_->GetTimeEvaluator(handle, ctx, nstep); handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat);
return WrapRemote(handle); return WrapRemote(handle);
} }
...@@ -148,10 +149,10 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator") ...@@ -148,10 +149,10 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
ctx.device_id = args[3]; ctx.device_id = args[3];
if (tkey == "rpc") { if (tkey == "rpc") {
*rv = static_cast<RPCModuleNode*>(m.operator->()) *rv = static_cast<RPCModuleNode*>(m.operator->())
->GetTimeEvaluator(args[1], ctx, args[4]); ->GetTimeEvaluator(args[1], ctx, args[4], args[5]);
} else { } else {
*rv = WrapTimeEvaluator( *rv = WrapTimeEvaluator(
m.GetFunction(args[1], false), ctx, args[4]); m.GetFunction(args[1], false), ctx, args[4], args[5]);
} }
}); });
......
...@@ -844,8 +844,9 @@ void RPCSession::CopyFromRemote(void* from, ...@@ -844,8 +844,9 @@ void RPCSession::CopyFromRemote(void* from,
} }
RPCFuncHandle RPCSession::GetTimeEvaluator( RPCFuncHandle RPCSession::GetTimeEvaluator(
RPCFuncHandle fhandle, TVMContext ctx, int nstep) { RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat) {
return this->CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep); return this->CallRemote(
RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat);
} }
// Event handler functions // Event handler functions
...@@ -973,7 +974,7 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { ...@@ -973,7 +974,7 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*()); PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2])); void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3]));
delete pf; delete pf;
*rv = fhandle; *rv = fhandle;
} }
...@@ -1024,23 +1025,31 @@ void RPCSession::EventHandler::HandlePackedCall() { ...@@ -1024,23 +1025,31 @@ void RPCSession::EventHandler::HandlePackedCall() {
CHECK_EQ(state_, kRecvCode); CHECK_EQ(state_, kRecvCode);
} }
PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) { PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat) {
auto ftimer = [pf, ctx, nstep](TVMArgs args, TVMRetValue *rv) { auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue *rv) {
TVMRetValue temp; TVMRetValue temp;
std::ostringstream os;
// skip first time call, to activate lazy compilation components. // skip first time call, to activate lazy compilation components.
pf.CallPacked(args, &temp); pf.CallPacked(args, &temp);
DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
for (int i = 0; i < repeat; ++i) {
// start timing // start timing
auto tbegin = std::chrono::high_resolution_clock::now(); auto tbegin = std::chrono::high_resolution_clock::now();
for (int i = 0; i < nstep; ++i) { for (int i = 0; i < number; ++i) {
pf.CallPacked(args, &temp); pf.CallPacked(args, &temp);
} }
DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
auto tend = std::chrono::high_resolution_clock::now(); auto tend = std::chrono::high_resolution_clock::now();
double speed = std::chrono::duration_cast<std::chrono::duration<double> >( double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
tend - tbegin).count() / nstep; tend - tbegin).count() / number;
os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
}
std::string blob = os.str();
TVMByteArray arr;
arr.size = blob.length();
arr.data = blob.data();
// return the time. // return the time.
*rv = speed; *rv = arr;
}; };
return PackedFunc(ftimer); return PackedFunc(ftimer);
} }
......
...@@ -146,12 +146,14 @@ class RPCSession { ...@@ -146,12 +146,14 @@ class RPCSession {
* *
* \param fhandle The function handle. * \param fhandle The function handle.
* \param ctx The ctx to run measurement on. * \param ctx The ctx to run measurement on.
* \param nstep Number of steps to run. * \param number How many steps to run in each time evaluation
* \param repeat How many times to repeat the timer
* \return A remote timer function * \return A remote timer function
*/ */
RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle, RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
TVMContext ctx, TVMContext ctx,
int nstep); int number,
int repeat);
/*! /*!
* \brief Call a remote defined system function with arguments. * \brief Call a remote defined system function with arguments.
* \param fcode The function code. * \param fcode The function code.
...@@ -212,9 +214,10 @@ class RPCSession { ...@@ -212,9 +214,10 @@ class RPCSession {
* \brief Wrap a timer function for a given packed function. * \brief Wrap a timer function for a given packed function.
* \param f The function argument. * \param f The function argument.
* \param ctx The context. * \param ctx The context.
* \param nstep Number of repeative steps. * \param number Number of steps in the inner iteration
* \param repeat How many steps to repeat the time evaluation.
*/ */
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep); PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int number, int repeat);
/*! /*!
* \brief Create a Global RPC module that refers to the session. * \brief Create a Global RPC module that refers to the session.
......
...@@ -55,7 +55,10 @@ def test_log_pow_llvm(): ...@@ -55,7 +55,10 @@ def test_log_pow_llvm():
n = 1028 n = 1028
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
flog(a, b) repeat = 10
ftimer = flog.time_evaluator(flog.entry_name, ctx, number=1, repeat=repeat)
res = ftimer(a, b)
assert(len(res.results) == repeat)
np.testing.assert_allclose( np.testing.assert_allclose(
b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5) b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
...@@ -146,7 +149,7 @@ def test_add(): ...@@ -146,7 +149,7 @@ def test_add():
if __name__ == "__main__": if __name__ == "__main__":
test_add()
test_log_pow_llvm() test_log_pow_llvm()
test_popcount()
test_exp() test_exp()
test_add()
test_popcount()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment