Commit ba8d00c2 by Tianqi Chen Committed by GitHub

[TIMER] Enhance time evaluator to create multiple results (#830)

parent a7cd0a89
"""Container of compiled functions of TVM."""
from __future__ import absolute_import as _abs
import struct
from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api
from .contrib import cc as _cc, tar as _tar, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"])
ProfileResult = namedtuple("ProfileResult", ["mean", "results"])
class Module(ModuleBase):
......@@ -110,7 +111,7 @@ class Module(ModuleBase):
fcompile = _cc.create_shared
fcompile(file_name, files, **kwargs)
def time_evaluator(self, func_name, ctx, number):
def time_evaluator(self, func_name, ctx, number, repeat=1):
"""Get an evaluator that measures time cost of running function.
Parameters
......@@ -122,11 +123,15 @@ class Module(ModuleBase):
The context we should run this function on.
number: int
The number of repeative times to run evaluation.
The number of steps used in measuring each time interval
repeat: int, optional
Number of times to run the timer measurement
If repeat equals 3, then we will get 3 numbers in the ProfileResult.
Note
----
The function will be invoked number + 1 times,
The function will be invoked repeat * number + 1 times,
with the first call discarded in case there is lazy initialization.
Returns
......@@ -137,13 +142,16 @@ class Module(ModuleBase):
"""
try:
feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number)
self, func_name, ctx.device_type, ctx.device_id, number, repeat)
def evaluator(*args):
"""Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future.
mean = feval(*args)
return ProfileResult(mean=mean)
blob = feval(*args)
fmt = "@" + ("d" * repeat)
results = struct.unpack(fmt, blob)
mean = sum(results) / float(repeat)
return ProfileResult(mean=mean, results=results)
return evaluator
except NameError:
......
......@@ -77,10 +77,11 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc GetTimeEvaluator(const std::string& name,
TVMContext ctx,
int nstep) {
int number,
int repeat) {
RPCFuncHandle handle = GetFuncHandle(name);
if (handle == nullptr) return PackedFunc();
handle = sess_->GetTimeEvaluator(handle, ctx, nstep);
handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat);
return WrapRemote(handle);
}
......@@ -148,10 +149,10 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
ctx.device_id = args[3];
if (tkey == "rpc") {
*rv = static_cast<RPCModuleNode*>(m.operator->())
->GetTimeEvaluator(args[1], ctx, args[4]);
->GetTimeEvaluator(args[1], ctx, args[4], args[5]);
} else {
*rv = WrapTimeEvaluator(
m.GetFunction(args[1], false), ctx, args[4]);
m.GetFunction(args[1], false), ctx, args[4], args[5]);
}
});
......
......@@ -844,8 +844,9 @@ void RPCSession::CopyFromRemote(void* from,
}
RPCFuncHandle RPCSession::GetTimeEvaluator(
RPCFuncHandle fhandle, TVMContext ctx, int nstep) {
return this->CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep);
RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat) {
return this->CallRemote(
RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat);
}
// Event handler functions
......@@ -973,7 +974,7 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2]));
void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3]));
delete pf;
*rv = fhandle;
}
......@@ -1024,23 +1025,31 @@ void RPCSession::EventHandler::HandlePackedCall() {
CHECK_EQ(state_, kRecvCode);
}
PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) {
auto ftimer = [pf, ctx, nstep](TVMArgs args, TVMRetValue *rv) {
PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat) {
auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue *rv) {
TVMRetValue temp;
std::ostringstream os;
// skip first time call, to activate lazy compilation components.
pf.CallPacked(args, &temp);
DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
for (int i = 0; i < repeat; ++i) {
// start timing
auto tbegin = std::chrono::high_resolution_clock::now();
for (int i = 0; i < nstep; ++i) {
for (int i = 0; i < number; ++i) {
pf.CallPacked(args, &temp);
}
DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
auto tend = std::chrono::high_resolution_clock::now();
double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
tend - tbegin).count() / nstep;
tend - tbegin).count() / number;
os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
}
std::string blob = os.str();
TVMByteArray arr;
arr.size = blob.length();
arr.data = blob.data();
// return the time.
*rv = speed;
*rv = arr;
};
return PackedFunc(ftimer);
}
......
......@@ -146,12 +146,14 @@ class RPCSession {
*
* \param fhandle The function handle.
* \param ctx The ctx to run measurement on.
* \param nstep Number of steps to run.
* \param number How many steps to run in each time evaluation
* \param repeat How many times to repeat the timer
* \return A remote timer function
*/
RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
TVMContext ctx,
int nstep);
int number,
int repeat);
/*!
* \brief Call a remote defined system function with arguments.
* \param fcode The function code.
......@@ -212,9 +214,10 @@ class RPCSession {
* \brief Wrap a timer function for a given packed function.
* \param f The function argument.
* \param ctx The context.
* \param nstep Number of repeative steps.
* \param number Number of steps in the inner iteration
* \param repeat How many steps to repeat the time evaluation.
*/
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep);
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int number, int repeat);
/*!
* \brief Create a Global RPC module that refers to the session.
......
......@@ -55,7 +55,10 @@ def test_log_pow_llvm():
n = 1028
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
flog(a, b)
repeat = 10
ftimer = flog.time_evaluator(flog.entry_name, ctx, number=1, repeat=repeat)
res = ftimer(a, b)
assert(len(res.results) == repeat)
np.testing.assert_allclose(
b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
......@@ -146,7 +149,7 @@ def test_add():
if __name__ == "__main__":
test_add()
test_log_pow_llvm()
test_popcount()
test_exp()
test_add()
test_popcount()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment