Commit 5912ed03 by Tianqi Chen Committed by GitHub

[PERF/TIMER] Add builtin timing logic (#168)

* [PERF/TIMER] Add builtin timing logic

* fix lint
parent 46b4a914
...@@ -56,10 +56,12 @@ class Function(_FunctionBase): ...@@ -56,10 +56,12 @@ class Function(_FunctionBase):
class ModuleBase(object): class ModuleBase(object):
"""Base class for module""" """Base class for module"""
__slots__ = ["handle", "_entry"] __slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle): def __init__(self, handle):
self.handle = handle self.handle = handle
self._entry = None self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self): def __del__(self):
check_call(_LIB.TVMModFree(self.handle)) check_call(_LIB.TVMModFree(self.handle))
...@@ -75,7 +77,7 @@ class ModuleBase(object): ...@@ -75,7 +77,7 @@ class ModuleBase(object):
""" """
if self._entry: if self._entry:
return self._entry return self._entry
self._entry = self.get_function("__tvm_main__") self._entry = self.get_function(self.entry_name)
return self._entry return self._entry
def get_function(self, name, query_imports=False): def get_function(self, name, query_imports=False):
......
...@@ -72,7 +72,7 @@ class Module(ModuleBase): ...@@ -72,7 +72,7 @@ class Module(ModuleBase):
The name of the shared library. The name of the shared library.
""" """
if self.type_key != "llvm": if self.type_key != "llvm":
raise ValueError("Only llvm support export shared") raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key)
temp = _util.tempdir() temp = _util.tempdir()
path_obj = temp.relpath("lib.o") path_obj = temp.relpath("lib.o")
self.save(path_obj) self.save(path_obj)
...@@ -84,6 +84,37 @@ class Module(ModuleBase): ...@@ -84,6 +84,37 @@ class Module(ModuleBase):
files.append(path_cc) files.append(path_cc)
_cc.create_shared(file_name, files) _cc.create_shared(file_name, files)
def time_evaluator(self, func_name, ctx, number):
    """Get an evaluator that measures time cost of running function.

    Parameters
    ----------
    func_name: str
        The name of the function in the module.

    ctx: TVMContext
        The context we should run this function on.

    number: int
        The number of repeated times to run evaluation.

    Note
    ----
    The function will be invoked number + 1 times,
    with the first call discarded in case there is lazy initialization.

    Returns
    -------
    ftimer : Function
        The function that takes same argument as func
        and return a float representing seconds per function call.

    Raises
    ------
    NameError
        If the timer backend is not available, i.e. RPC is not enabled.
    """
    try:
        # Guard only the name lookup: a NameError raised while the call
        # itself runs must not be reported as "RPC is not enabled".
        make_timer = _RPCTimeEvaluator
    except NameError:
        raise NameError("time_evaluator is only supported when RPC is enabled")
    return make_timer(
        self, func_name, ctx.device_type, ctx.device_id, number)
def load(path, fmt=""): def load(path, fmt=""):
"""Load module from file """Load module from file
......
...@@ -51,18 +51,8 @@ class RPCModuleNode final : public ModuleNode { ...@@ -51,18 +51,8 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
RPCFuncHandle handle = nullptr; RPCFuncHandle handle = GetFuncHandle(name);
if (module_handle_ == nullptr) { return WrapRemote(handle);
handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
} else {
handle = sess_->CallRemote(
RPCCode::kModuleGetFunc, module_handle_, name);
}
if (handle == nullptr) return PackedFunc();
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv);
});
} }
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
...@@ -86,7 +76,34 @@ class RPCModuleNode final : public ModuleNode { ...@@ -86,7 +76,34 @@ class RPCModuleNode final : public ModuleNode {
return sess_; return sess_;
} }
/*!
 * \brief Build a timer function for the function called `name`.
 * \param name Name of the function to look up through GetFuncHandle.
 * \param ctx The context forwarded to the session's time evaluator.
 * \param nstep The step count forwarded to the session's time evaluator.
 * \return The wrapped timer function, or an empty PackedFunc when the
 *  lookup yields a null handle.
 */
PackedFunc GetTimeEvaluator(const std::string& name,
                            TVMContext ctx,
                            int nstep) {
  RPCFuncHandle fhandle = GetFuncHandle(name);
  if (fhandle != nullptr) {
    // Exchange the function handle for a handle to its timer, then wrap it.
    RPCFuncHandle timer_handle = sess_->GetTimeEvaluator(fhandle, ctx, nstep);
    return WrapRemote(timer_handle);
  }
  return PackedFunc();
}
private: private:
/*!
 * \brief Turn a function handle into a callable PackedFunc.
 * \param handle The function handle; may be nullptr.
 * \return An empty PackedFunc for a null handle, otherwise a PackedFunc
 *  that forwards its arguments to an RPCWrappedFunc built from the handle
 *  and this module's session.
 */
PackedFunc WrapRemote(RPCFuncHandle handle) {
  if (handle != nullptr) {
    // The closure holds the wrapper through a shared_ptr, so copies of the
    // returned PackedFunc refer to the same RPCWrappedFunc instance.
    std::shared_ptr<RPCWrappedFunc> wrapped =
        std::make_shared<RPCWrappedFunc>(handle, sess_);
    return PackedFunc([wrapped](TVMArgs args, TVMRetValue* rv) {
        return (*wrapped)(args, rv);
      });
  }
  return PackedFunc();
}
/*!
 * \brief Ask the session for the handle of the function called `name`.
 * \param name The function name.
 * \return The handle produced by the remote call. Callers in this class
 *  treat nullptr as "no function".
 */
RPCFuncHandle GetFuncHandle(const std::string& name) {
  if (module_handle_ != nullptr) {
    // A module handle is set: query that module.
    RPCFuncHandle from_module = sess_->CallRemote(
        RPCCode::kModuleGetFunc, module_handle_, name);
    return from_module;
  }
  // No module handle: issue the global-function query instead.
  RPCFuncHandle from_global =
      sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
  return from_global;
}
// The module handle // The module handle
void* module_handle_{nullptr}; void* module_handle_{nullptr};
// The local channel // The local channel
...@@ -123,6 +140,22 @@ TVM_REGISTER_GLOBAL("contrib.rpc._Connect") ...@@ -123,6 +140,22 @@ TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
*rv = RPCConnect(args[0], args[1]); *rv = RPCConnect(args[0], args[1]);
}); });
// Global function "module._RPCTimeEvaluator": build a timer for a function
// of a module.
// Packed arguments:
//   args[0]  the module that owns the function
//   args[1]  name of the function to time
//   args[2]  device type, passed as int
//   args[3]  device id
//   args[4]  number of timed calls (nstep)
// For a module whose type key is "rpc" the timer is requested through the
// RPC module; for any other module the function is wrapped directly with
// WrapTimeEvaluator.
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
    ctx.device_id = args[3];
    if (tkey == "rpc") {
      *rv = static_cast<RPCModuleNode*>(m.operator->())
          ->GetTimeEvaluator(args[1], ctx, args[4]);
    } else {
      // The repeat count is args[4]; args[3] is the device id and must not
      // be used as nstep (device 0 would mean zero timed calls).
      *rv = WrapTimeEvaluator(
          m.GetFunction(args[1], false), ctx, args[4]);
    }
  });
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule") TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0]; Module m = args[0];
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <memory> #include <memory>
#include <array> #include <array>
#include <chrono>
#include "./rpc_session.h" #include "./rpc_session.h"
#include "../device_api.h" #include "../device_api.h"
...@@ -181,6 +182,11 @@ void RPCSession::CopyFromRemote(void* from, ...@@ -181,6 +182,11 @@ void RPCSession::CopyFromRemote(void* from,
} }
} }
// Request a timer function for fhandle by issuing a kGetTimeEvaluator
// call with the context and step count; returns the resulting handle.
RPCFuncHandle RPCSession::GetTimeEvaluator(
    RPCFuncHandle fhandle, TVMContext ctx, int nstep) {
  RPCFuncHandle timer_handle =
      this->CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep);
  return timer_handle;
}
void RPCSession::SendReturnValue( void RPCSession::SendReturnValue(
int succ, TVMValue ret_value, int ret_tcode) { int succ, TVMValue ret_value, int ret_tcode) {
if (succ == 0) { if (succ == 0) {
...@@ -593,6 +599,13 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { ...@@ -593,6 +599,13 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
*rv = (*static_cast<Module*>(mhandle))->GetSource(fmt); *rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
} }
// Handler for the time-evaluator request.
//   args[0]  pointer to a heap-allocated PackedFunc (the function to time)
//   args[1]  context, args[2]  step count; both go to WrapTimeEvaluator
// The incoming PackedFunc object is deleted here, so the handle passed in
// is consumed; the result is a pointer to a newly allocated PackedFunc.
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
  PackedFunc *target = static_cast<PackedFunc*>(args[0].operator void*());
  // WrapTimeEvaluator takes the function by value, so the timer keeps its
  // own copy and the original object can be released right after.
  PackedFunc *timer =
      new PackedFunc(WrapTimeEvaluator(*target, args[1], args[2]));
  delete target;
  void *timer_handle = timer;
  *rv = timer_handle;
}
RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) { RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
RPCCode code; RPCCode code;
CHECK_EQ(sock_.RecvAll(&code, sizeof(int)), sizeof(int)); CHECK_EQ(sock_.RecvAll(&code, sizeof(int)), sizeof(int));
...@@ -604,6 +617,7 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) { ...@@ -604,6 +617,7 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
case RPCCode::kCopyToRemote: HandleCopyToRemote(); break; case RPCCode::kCopyToRemote: HandleCopyToRemote(); break;
case RPCCode::kShutdown: break; case RPCCode::kShutdown: break;
// system functions // system functions
case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break; case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break; case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break; case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
...@@ -620,5 +634,26 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) { ...@@ -620,5 +634,26 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
} }
return code; return code;
} }
// Build a PackedFunc that times `pf`: it calls pf once untimed, then
// nstep more times between two clock readings, and returns the elapsed
// seconds divided by nstep.
PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) {
  return PackedFunc([pf, ctx, nstep](TVMArgs args, TVMRetValue *rv) {
      using Clock = std::chrono::high_resolution_clock;
      TVMRetValue scratch;
      // Untimed first call, so lazily compiled components are activated
      // before the measurement starts.
      pf.CallPacked(args, &scratch);
      DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
      // Timed section: nstep calls followed by a stream sync.
      const Clock::time_point started = Clock::now();
      for (int remaining = nstep; remaining > 0; --remaining) {
        pf.CallPacked(args, &scratch);
      }
      DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
      const Clock::time_point finished = Clock::now();
      // Elapsed time as floating-point seconds.
      const std::chrono::duration<double> elapsed = finished - started;
      // Report the average seconds per call.
      *rv = elapsed.count() / nstep;
    });
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -31,6 +31,7 @@ enum class RPCCode : int { ...@@ -31,6 +31,7 @@ enum class RPCCode : int {
kCopyAck, kCopyAck,
// The following are code that can send over CallRemote // The following are code that can send over CallRemote
kGetGlobalFunc, kGetGlobalFunc,
kGetTimeEvaluator,
kFreeFunc, kFreeFunc,
kDevSetDevice, kDevSetDevice,
kDevGetAttr, kDevGetAttr,
...@@ -93,6 +94,18 @@ class RPCSession { ...@@ -93,6 +94,18 @@ class RPCSession {
size_t size, size_t size,
TVMContext ctx_from); TVMContext ctx_from);
/*!
 * \brief Get a remote timer function on ctx.
 *  This function consumes fhandle, caller should not call Free on fhandle
 *  (the handler for this request deletes the original function object).
 *
 * \param fhandle The handle of the function to be timed.
 * \param ctx The ctx to run measurement on.
 * \param nstep Number of steps to run, i.e. how many timed calls are made
 *  per measurement.
 * \return A handle to the remote timer function.
 */
RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
TVMContext ctx,
int nstep);
/*!
* \brief Call a remote defined system function with arguments. * \brief Call a remote defined system function with arguments.
* \param fcode The function code. * \param fcode The function code.
* \param args The arguments * \param args The arguments
...@@ -133,13 +146,13 @@ class RPCSession { ...@@ -133,13 +146,13 @@ class RPCSession {
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n); void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n);
void RecvPackedSeq(RPCArgBuffer *buf); void RecvPackedSeq(RPCArgBuffer *buf);
RPCCode HandleNextEvent(TVMRetValue *rv); RPCCode HandleNextEvent(TVMRetValue *rv);
TVMContext StripSessMask(TVMContext ctx);
// special handler. // special handler.
void HandleCallFunc(); void HandleCallFunc();
void HandleException(); void HandleException();
void HandleCopyFromRemote(); void HandleCopyFromRemote();
void HandleCopyToRemote(); void HandleCopyToRemote();
void HandleReturn(TVMRetValue* rv); void HandleReturn(TVMRetValue* rv);
TVMContext StripSessMask(TVMContext ctx);
// Internal mutex // Internal mutex
std::recursive_mutex mutex_; std::recursive_mutex mutex_;
// Internal socket // Internal socket
...@@ -152,6 +165,14 @@ class RPCSession { ...@@ -152,6 +165,14 @@ class RPCSession {
int table_index_{0}; int table_index_{0};
}; };
/*!
 * \brief Wrap a timer function for a given packed function.
 *  Each invocation of the timer calls f once untimed and then nstep
 *  times under measurement.
 * \param f The function to be timed.
 * \param ctx The context whose stream is synchronized before and after
 *  the timed calls.
 * \param nstep Number of repeated steps (timed calls).
 * \return A function that takes the same arguments as f and returns the
 *  average time per call in seconds.
 */
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep);
// Remote space pointer. // Remote space pointer.
struct RemoteSpace { struct RemoteSpace {
void* data; void* data;
......
...@@ -95,7 +95,8 @@ def test_add(): ...@@ -95,7 +95,8 @@ def test_add():
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
vbias = np.random.uniform() vbias = np.random.uniform()
vscale = np.random.uniform() vscale = np.random.uniform()
fadd(a, b, c, vbias, vscale) ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1000)
tcost = ftimer(a, b, c, vbias, vscale)
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6) c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)
......
...@@ -78,14 +78,9 @@ def test_gemm(): ...@@ -78,14 +78,9 @@ def test_gemm():
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx) b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c) ftimer = f.time_evaluator(f.entry_name, ctx, number=20)
ctx.sync() tcost = ftimer(a, b, c)
tbegin = time.time() print("%s: exec=%g sec/op" % (ctx, tcost))
f(a, b, c)
tpush = time.time()
ctx.sync()
tend = time.time()
print("launch=%g sec, exec=%g sec" % (tpush - tbegin, tend - tbegin))
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
......
...@@ -70,7 +70,9 @@ def test_rpc_remote_module(): ...@@ -70,7 +70,9 @@ def test_rpc_remote_module():
f1 = remote.load_module("dev_lib.so") f1 = remote.load_module("dev_lib.so")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b)
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote() check_remote()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment