Commit 47b8c36d by Siju Committed by Tianqi Chen

[RUNTIME][DEBUG]Support remote debugging (#1866)

parent 0b4cc050
......@@ -5,8 +5,9 @@ import tempfile
import shutil
from datetime import datetime
from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime
from tvm._ffi.function import get_global_func
from tvm.contrib import graph_runtime
from tvm.rpc import base as rpc_base
from . import debug_result
_DUMP_ROOT_PREFIX = "tvmdbg_"
......@@ -49,8 +50,12 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
raise NotSupportedError("Remote graph debugging is not supported.")
libmod = rpc_base._ModuleHandle(libmod)
try:
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_debug.remote_create")
except ValueError:
raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \
"config.cmake and rebuild TVM to enable debug mode")
func_obj = fcreate(graph_json_str, libmod, *device_type_id)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)
......
......@@ -146,5 +146,18 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
<< args.num_args;
*rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
});
TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeDebugCreate(
args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
});
} // namespace runtime
} // namespace tvm
......@@ -2,6 +2,8 @@ import os
import tvm
import numpy as np
import json
from tvm import rpc
from tvm.contrib import util
from tvm.contrib.debugger import debug_runtime as graph_runtime
def test_graph_simple():
......@@ -70,7 +72,32 @@ def test_graph_simple():
#verify dump root delete after cleanup
assert(not os.path.exists(directory))
def check_remote():
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
mlib = tvm.build(s, [A, B], "llvm", name="myadd")
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
temp = util.tempdir()
ctx = remote.cpu(0)
path_dso = temp.relpath("dev_lib.so")
mlib.export_library(path_dso)
remote.upload(path_dso)
mlib = remote.load_module("dev_lib.so")
try:
mod = graph_runtime.create(graph, mlib, remote.cpu(0))
except ValueError:
print("Skip because debug graph_runtime not enabled")
return
a = np.random.uniform(size=(n,)).astype(A.dtype)
mod.run(x=tvm.nd.array(a, ctx))
out = tvm.nd.empty((n,), ctx=ctx)
out = mod.get_output(0, out)
np.testing.assert_equal(out.asnumpy(), a + 1)
check_verify()
check_remote()
if __name__ == "__main__":
test_graph_simple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment