Commit 5912ed03 by Tianqi Chen Committed by GitHub

[PERF/TIMER] Add builtin timing logic (#168)

* [PERF/TIMER] Add builtin timing logic

* fix lint
parent 46b4a914
......@@ -56,10 +56,12 @@ class Function(_FunctionBase):
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry"]
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
......@@ -75,7 +77,7 @@ class ModuleBase(object):
"""
if self._entry:
return self._entry
self._entry = self.get_function("__tvm_main__")
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
......
......@@ -72,7 +72,7 @@ class Module(ModuleBase):
The name of the shared library.
"""
if self.type_key != "llvm":
raise ValueError("Only llvm support export shared")
raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key)
temp = _util.tempdir()
path_obj = temp.relpath("lib.o")
self.save(path_obj)
......@@ -84,6 +84,37 @@ class Module(ModuleBase):
files.append(path_cc)
_cc.create_shared(file_name, files)
def time_evaluator(self, func_name, ctx, number):
    """Get an evaluator that measures time cost of running function.

    Parameters
    ----------
    func_name: str
        The name of the function in the module.

    ctx: TVMContext
        The context we should run this function on.

    number: int
        The number of repeated runs used for the measurement.

    Note
    ----
    The function will be invoked number + 1 times,
    with the first call discarded in case there is lazy initialization.

    Returns
    -------
    ftimer : Function
        The function that takes same argument as func
        and return a float representing seconds per function call.

    Raises
    ------
    NameError
        If the `_RPCTimeEvaluator` backend function is not available.
    """
    # Only the name lookup is guarded: a NameError raised while the
    # evaluator itself runs must not be misreported as "RPC not enabled".
    try:
        feval = _RPCTimeEvaluator
    except NameError:
        raise NameError("time_evaluator is only supported when RPC is enabled")
    return feval(self, func_name, ctx.device_type, ctx.device_id, number)
def load(path, fmt=""):
"""Load module from file
......
......@@ -51,18 +51,8 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
RPCFuncHandle handle = nullptr;
if (module_handle_ == nullptr) {
handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
} else {
handle = sess_->CallRemote(
RPCCode::kModuleGetFunc, module_handle_, name);
}
if (handle == nullptr) return PackedFunc();
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv);
});
RPCFuncHandle handle = GetFuncHandle(name);
return WrapRemote(handle);
}
void SaveToFile(const std::string& file_name,
......@@ -86,7 +76,34 @@ class RPCModuleNode final : public ModuleNode {
return sess_;
}
/*!
 * \brief Create a timer for the function called `name`.
 * \param name The name of the function to be timed.
 * \param ctx The context forwarded to the session's time evaluator.
 * \param nstep The step count forwarded to the session's time evaluator.
 * \return The wrapped timer, or an empty PackedFunc when the
 *  function lookup yields a null handle.
 */
PackedFunc GetTimeEvaluator(const std::string& name,
                            TVMContext ctx,
                            int nstep) {
  RPCFuncHandle fhandle = GetFuncHandle(name);
  if (fhandle != nullptr) {
    // Exchange the function handle for a timer handle, then wrap that.
    RPCFuncHandle timer = sess_->GetTimeEvaluator(fhandle, ctx, nstep);
    return WrapRemote(timer);
  }
  return PackedFunc();
}
private:
/*!
 * \brief Turn a function handle into a callable PackedFunc.
 * \param handle The function handle, may be nullptr.
 * \return A PackedFunc forwarding every call to an RPCWrappedFunc built
 *  from the handle and this module's session; an empty PackedFunc when
 *  handle is nullptr.
 */
PackedFunc WrapRemote(RPCFuncHandle handle) {
  if (handle != nullptr) {
    // The lambda captures the shared_ptr by value, which keeps the
    // wrapper object alive for as long as the lambda itself.
    auto wrapped = std::make_shared<RPCWrappedFunc>(handle, sess_);
    return PackedFunc([wrapped](TVMArgs args, TVMRetValue* rv) {
        return (*wrapped)(args, rv);
      });
  }
  return PackedFunc();
}
/*!
 * \brief Look up the handle of a function by name through the session.
 * \param name The function name.
 * \return The handle returned by the remote call; callers in this class
 *  treat nullptr as "not found".
 */
RPCFuncHandle GetFuncHandle(const std::string& name) {
  // Without a module handle the name is looked up as a global function.
  if (module_handle_ == nullptr) {
    return sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
  }
  return sess_->CallRemote(RPCCode::kModuleGetFunc, module_handle_, name);
}
// The module handle
void* module_handle_{nullptr};
// The local channel
......@@ -123,6 +140,22 @@ TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
*rv = RPCConnect(args[0], args[1]);
});
// Create a time evaluator for a function of a module.
// Packed arguments:
//   args[0]: the module that holds the function
//   args[1]: the function name
//   args[2]: device type of the context to run on
//   args[3]: device id of the context to run on
//   args[4]: number of timed runs
// Modules whose type key is "rpc" build the timer through the RPC module,
// every other module is timed by wrapping the function directly.
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
    ctx.device_id = args[3];
    // Read the repeat count once so both branches use the same argument;
    // the local branch previously passed args[3] (the device id) instead.
    int nstep = args[4];
    if (tkey == "rpc") {
      *rv = static_cast<RPCModuleNode*>(m.operator->())
          ->GetTimeEvaluator(args[1], ctx, nstep);
    } else {
      *rv = WrapTimeEvaluator(
          m.GetFunction(args[1], false), ctx, nstep);
    }
  });
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
......
......@@ -6,6 +6,7 @@
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <array>
#include <chrono>
#include "./rpc_session.h"
#include "../device_api.h"
......@@ -181,6 +182,11 @@ void RPCSession::CopyFromRemote(void* from,
}
}
// Request a timer for fhandle by issuing RPCCode::kGetTimeEvaluator
// with the handle, the context and the step count as arguments.
RPCFuncHandle RPCSession::GetTimeEvaluator(RPCFuncHandle fhandle,
                                           TVMContext ctx,
                                           int nstep) {
  RPCFuncHandle timer =
      this->CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep);
  return timer;
}
void RPCSession::SendReturnValue(
int succ, TVMValue ret_value, int ret_tcode) {
if (succ == 0) {
......@@ -593,6 +599,13 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
*rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
}
/*!
 * \brief Handler dispatched for RPCCode::kGetTimeEvaluator.
 *
 *  args[0] is a pointer to a heap-allocated PackedFunc, args[1] and args[2]
 *  are forwarded to WrapTimeEvaluator as the context and the step count.
 *  The incoming PackedFunc is consumed: it is destroyed before the handler
 *  returns, whether or not building the timer succeeds.
 *
 * \param args The packed arguments described above.
 * \param rv Receives a pointer to a newly allocated PackedFunc timer.
 */
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
  // Take ownership right away so the handle is released even when the
  // argument conversion or WrapTimeEvaluator throws.
  std::unique_ptr<PackedFunc> pf(
      static_cast<PackedFunc*>(args[0].operator void*()));
  // The timer holds its own copy of *pf, so dropping pf afterwards is safe.
  void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2]));
  *rv = fhandle;
}
RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
RPCCode code;
CHECK_EQ(sock_.RecvAll(&code, sizeof(int)), sizeof(int));
......@@ -604,6 +617,7 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
case RPCCode::kCopyToRemote: HandleCopyToRemote(); break;
case RPCCode::kShutdown: break;
// system functions
case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
......@@ -620,5 +634,26 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
}
return code;
}
/*!
 * \brief Build a function that reports the average run time of pf.
 *
 *  The returned function forwards its arguments to pf. It calls pf once
 *  untimed, then nstep times between two clock readings, synchronizing the
 *  stream of ctx before each reading, and returns the elapsed seconds
 *  divided by nstep as a double.
 *
 * \param pf The function to be measured.
 * \param ctx The context whose stream is synchronized around the timing.
 * \param nstep Number of timed invocations.
 * \return The timer function.
 */
PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) {
  return PackedFunc([pf, ctx, nstep](TVMArgs args, TVMRetValue *rv) {
      using Clock = std::chrono::high_resolution_clock;
      TVMRetValue scratch;
      // Untimed first call, so lazily initialized components do not
      // count towards the measurement.
      pf.CallPacked(args, &scratch);
      DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
      // Timed section.
      const Clock::time_point started = Clock::now();
      for (int remaining = nstep; remaining > 0; --remaining) {
        pf.CallPacked(args, &scratch);
      }
      DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
      const Clock::time_point finished = Clock::now();
      // Seconds as a double, averaged over the timed calls.
      const std::chrono::duration<double> elapsed = finished - started;
      *rv = elapsed.count() / nstep;
    });
}
} // namespace runtime
} // namespace tvm
......@@ -31,6 +31,7 @@ enum class RPCCode : int {
kCopyAck,
// The following are code that can send over CallRemote
kGetGlobalFunc,
kGetTimeEvaluator,
kFreeFunc,
kDevSetDevice,
kDevGetAttr,
......@@ -93,6 +94,18 @@ class RPCSession {
size_t size,
TVMContext ctx_from);
/*!
* \brief Get a remote timer function on ctx.
* This function consumes fhandle, caller should not call Free on fhandle.
*
* \param fhandle The function handle.
* \param ctx The ctx to run measurement on.
* \param nstep Number of steps to run.
* \return A remote timer function
*/
RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
TVMContext ctx,
int nstep);
/*!
* \brief Call a remote defined system function with arguments.
* \param fcode The function code.
* \param args The arguments
......@@ -133,13 +146,13 @@ class RPCSession {
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n);
void RecvPackedSeq(RPCArgBuffer *buf);
RPCCode HandleNextEvent(TVMRetValue *rv);
TVMContext StripSessMask(TVMContext ctx);
// special handler.
void HandleCallFunc();
void HandleException();
void HandleCopyFromRemote();
void HandleCopyToRemote();
void HandleReturn(TVMRetValue* rv);
TVMContext StripSessMask(TVMContext ctx);
// Internal mutex
std::recursive_mutex mutex_;
// Internal socket
......@@ -152,6 +165,14 @@ class RPCSession {
int table_index_{0};
};
/*!
 * \brief Wrap a timer function for a given packed function.
 * \param f The function to be timed.
 * \param ctx The context.
 * \param nstep Number of repeated steps.
 * \return A function that forwards its arguments to f and returns
 *  the measured seconds per call.
 */
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep);
// Remote space pointer.
struct RemoteSpace {
void* data;
......
......@@ -95,7 +95,8 @@ def test_add():
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
vbias = np.random.uniform()
vscale = np.random.uniform()
fadd(a, b, c, vbias, vscale)
ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1000)
tcost = ftimer(a, b, c, vbias, vscale)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)
......
......@@ -78,14 +78,9 @@ def test_gemm():
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
ctx.sync()
tbegin = time.time()
f(a, b, c)
tpush = time.time()
ctx.sync()
tend = time.time()
print("launch=%g sec, exec=%g sec" % (tpush - tbegin, tend - tbegin))
ftimer = f.time_evaluator(f.entry_name, ctx, number=20)
tcost = ftimer(a, b, c)
print("%s: exec=%g sec/op" % (ctx, tcost))
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
......
......@@ -70,7 +70,9 @@ def test_rpc_remote_module():
f1 = remote.load_module("dev_lib.so")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b)
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment