Commit 81db22c5 by Tianqi Chen Committed by GitHub

[RPC] graduate tvm.contrib.rpc -> tvm.rpc (#1410)

parent 68e4a111
...@@ -6,7 +6,8 @@ And configure the proxy host field as commented. ...@@ -6,7 +6,8 @@ And configure the proxy host field as commented.
import tvm import tvm
import os import os
from tvm.contrib import rpc, util, ndk from tvm import rpc
from tvm.contrib import util, ndk
import numpy as np import numpy as np
# Set to be address of tvm proxy. # Set to be address of tvm proxy.
......
...@@ -6,7 +6,8 @@ And configure the proxy host field as commented. ...@@ -6,7 +6,8 @@ And configure the proxy host field as commented.
import tvm import tvm
import os import os
from tvm.contrib import rpc, util, xcode from tvm import rpc
from tvm.contrib import util, xcode
import numpy as np import numpy as np
# Set to be address of tvm proxy. # Set to be address of tvm proxy.
......
...@@ -103,13 +103,13 @@ void LaunchSyncServer() { ...@@ -103,13 +103,13 @@ void LaunchSyncServer() {
->ServerLoop(); ->ServerLoop();
} }
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath") TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env; static RPCEnv env;
*rv = env.GetPath(args[0]); *rv = env.GetPath(args[0]);
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module") TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
std::string name = args[0]; std::string name = args[0];
std::string fmt = GetFileFormat(name, ""); std::string fmt = GetFileFormat(name, "");
......
tvm.contrib.rpc tvm.rpc
--------------- -------
.. automodule:: tvm.contrib.rpc .. automodule:: tvm.rpc
.. autofunction:: tvm.contrib.rpc.connect .. autofunction:: tvm.rpc.connect
.. autofunction:: tvm.contrib.rpc.connect_tracker .. autofunction:: tvm.rpc.connect_tracker
.. autoclass:: tvm.contrib.rpc.TrackerSession .. autoclass:: tvm.rpc.TrackerSession
:members: :members:
:inherited-members: :inherited-members:
.. autoclass:: tvm.contrib.rpc.RPCSession .. autoclass:: tvm.rpc.RPCSession
:members: :members:
:inherited-members: :inherited-members:
.. autoclass:: tvm.contrib.rpc.LocalSession .. autoclass:: tvm.rpc.LocalSession
:members: :members:
:inherited-members: :inherited-members:
.. autoclass:: tvm.contrib.rpc.Server .. autoclass:: tvm.rpc.Server
:members: :members:
:inherited-members: :inherited-members:
...@@ -19,10 +19,10 @@ We can then use the following command to launch a `tvmai/demo-cpu` image. ...@@ -19,10 +19,10 @@ We can then use the following command to launch a `tvmai/demo-cpu` image.
.. code:: bash .. code:: bash
/path/to/tvm/docker/bash.sh tvmai/demo_cpu /path/to/tvm/docker/bash.sh tvmai/demo-cpu
.. note:: You can also change `demo-cpu` to `demo-gpu` to get a CUDA enabled image.
You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_ You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_
This auxiliary script does the following things: This auxiliary script does the following things:
......
...@@ -70,13 +70,13 @@ public class NativeServerLoop implements Runnable { ...@@ -70,13 +70,13 @@ public class NativeServerLoop implements Runnable {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath()); throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
} }
Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() { Function.register("tvm.rpc.server.workpath", new Function.Callback() {
@Override public Object invoke(TVMValue... args) { @Override public Object invoke(TVMValue... args) {
return tempDir + File.separator + args[0].asString(); return tempDir + File.separator + args[0].asString();
} }
}, true); }, true);
Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() { Function.register("tvm.rpc.server.load_module", new Function.Callback() {
@Override public Object invoke(TVMValue... args) { @Override public Object invoke(TVMValue... args) {
String filename = args[0].asString(); String filename = args[0].asString();
String path = tempDir + File.separator + filename; String path = tempDir + File.separator + filename;
......
...@@ -37,7 +37,7 @@ public class RPC { ...@@ -37,7 +37,7 @@ public class RPC {
static Function getApi(String name) { static Function getApi(String name) {
Function func = apiFuncs.get().get(name); Function func = apiFuncs.get().get(name);
if (func == null) { if (func == null) {
func = Function.getFunction("contrib.rpc." + name); func = Function.getFunction("rpc." + name);
if (func == null) { if (func == null) {
return null; return null;
} }
......
...@@ -172,7 +172,7 @@ public class RPCSession { ...@@ -172,7 +172,7 @@ public class RPCSession {
final String funcName = "upload"; final String funcName = "upload";
Function remoteFunc = remoteFuncs.get(funcName); Function remoteFunc = remoteFuncs.get(funcName);
if (remoteFunc == null) { if (remoteFunc == null) {
remoteFunc = getFunction("tvm.contrib.rpc.server.upload"); remoteFunc = getFunction("tvm.rpc.server.upload");
remoteFuncs.put(funcName, remoteFunc); remoteFuncs.put(funcName, remoteFunc);
} }
remoteFunc.pushArg(target).pushArg(data).invoke(); remoteFunc.pushArg(target).pushArg(data).invoke();
...@@ -205,7 +205,7 @@ public class RPCSession { ...@@ -205,7 +205,7 @@ public class RPCSession {
final String name = "download"; final String name = "download";
Function func = remoteFuncs.get(name); Function func = remoteFuncs.get(name);
if (func == null) { if (func == null) {
func = getFunction("tvm.contrib.rpc.server.download"); func = getFunction("tvm.rpc.server.download");
remoteFuncs.put(name, func); remoteFuncs.put(name, func);
} }
return func.pushArg(path).invoke().asBytes(); return func.pushArg(path).invoke().asBytes();
......
import time import time
from tvm.contrib.rpc import proxy from tvm.rpc import proxy
def start_proxy_server(port, timeout): def start_proxy_server(port, timeout):
prox = proxy.Proxy("localhost", port=port, port_end=port+1) prox = proxy.Proxy("localhost", port=port, port_end=port+1)
......
...@@ -2,7 +2,8 @@ import os ...@@ -2,7 +2,8 @@ import os
import numpy as np import numpy as np
import nnvm.compiler import nnvm.compiler
import tvm import tvm
from tvm.contrib import rpc, util, graph_runtime from tvm import rpc
from tvm.contrib import util, graph_runtime
def test_save_load(): def test_save_load():
......
import tvm import tvm
from tvm.contrib import util, rpc, graph_runtime from tvm import rpc
from tvm.contrib import util, graph_runtime
import nnvm.symbol as sym import nnvm.symbol as sym
import nnvm.compiler import nnvm.compiler
import numpy as np import numpy as np
......
"""Minimum graph runtime that executes graph containing TVM PackedFunc.""" """Minimum graph runtime that executes graph containing TVM PackedFunc."""
from .._ffi.base import string_types from .._ffi.base import string_types
from .._ffi.function import get_global_func from .._ffi.function import get_global_func
from .rpc import base as rpc_base from ..rpc import base as rpc_base
from .. import ndarray as nd from .. import ndarray as nd
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
"""measure bandwidth and compute peak""" """measure bandwidth and compute peak"""
import logging import logging
import tvm import tvm
from tvm.contrib import rpc, util from . import util
from .. import rpc
def _convert_to_remote(func, remote): def _convert_to_remote(func, remote):
""" convert module function to remote rpc function""" """ convert module function to remote rpc function"""
...@@ -47,7 +47,7 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride, ...@@ -47,7 +47,7 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride,
host compilation target host compilation target
ctx: TVMcontext ctx: TVMcontext
the context of array the context of array
remote: tvm.contrib.rpc.RPCSession remote: tvm.rpc.RPCSession
remote rpc session remote rpc session
n_times: int n_times: int
number of runs for taking mean number of runs for taking mean
...@@ -107,7 +107,7 @@ def measure_bandwidth_all_types(total_item, item_per_thread, n_times, ...@@ -107,7 +107,7 @@ def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation. the target and option of the compilation.
target_host : str or :any:`tvm.target.Target` target_host : str or :any:`tvm.target.Target`
host compilation target host compilation target
remote: tvm.contrib.rpc.RPCSession remote: tvm.rpc.RPCSession
remote rpc session remote rpc session
ctx: TVMcontext ctx: TVMcontext
the context of array the context of array
...@@ -165,7 +165,7 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes, ...@@ -165,7 +165,7 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
the target and option of the compilation. the target and option of the compilation.
target_host : str or :any:`tvm.target.Target` target_host : str or :any:`tvm.target.Target`
host compilation target host compilation target
remote: tvm.contrib.rpc.RPCSession remote: tvm.rpc.RPCSession
if it is not None, use remote rpc session if it is not None, use remote rpc session
ctx: TVMcontext ctx: TVMcontext
the context of array the context of array
...@@ -250,7 +250,7 @@ def measure_compute_all_types(total_item, item_per_thread, n_times, ...@@ -250,7 +250,7 @@ def measure_compute_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation. the target and option of the compilation.
target_host : str or :any:`tvm.target.Target` target_host : str or :any:`tvm.target.Target`
host compilation target host compilation target
remote: tvm.contrib.rpc.RPCSession remote: tvm.rpc.RPCSession
remote rpc session remote rpc session
ctx: TVMcontext ctx: TVMcontext
the context of array the context of array
......
"""Deprecation RPC module"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
import warnings
from ..rpc import Server, RPCSession, LocalSession, TrackerSession, connect, connect_tracker
warnings.warn(
"Please use tvm.rpc instead of tvm.conrtib.rpc. tvm.contrib.rpc is going to be removed in 0.5",
DeprecationWarning)
...@@ -4,7 +4,7 @@ from __future__ import absolute_import ...@@ -4,7 +4,7 @@ from __future__ import absolute_import
import logging import logging
import argparse import argparse
import os import os
from ..contrib import rpc from .. import rpc
def main(): def main():
"""Main funciton""" """Main funciton"""
......
...@@ -7,7 +7,7 @@ import argparse ...@@ -7,7 +7,7 @@ import argparse
import multiprocessing import multiprocessing
import sys import sys
import os import os
from ..contrib.rpc.proxy import Proxy from ..rpc.proxy import Proxy
def find_example_resource(): def find_example_resource():
......
...@@ -6,7 +6,7 @@ import argparse ...@@ -6,7 +6,7 @@ import argparse
import multiprocessing import multiprocessing
import sys import sys
import logging import logging
from ..contrib import rpc from .. import rpc
def main(args): def main(args):
"""Main function""" """Main function"""
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
import argparse import argparse
import multiprocessing import multiprocessing
import sys import sys
from ..contrib.rpc.tracker import Tracker from ..rpc.tracker import Tracker
def main(args): def main(args):
......
...@@ -9,8 +9,8 @@ import struct ...@@ -9,8 +9,8 @@ import struct
import random import random
import logging import logging
from ..._ffi.function import _init_api from .._ffi.function import _init_api
from ..._ffi.base import py_str from .._ffi.base import py_str
# Magic header for RPC data plane # Magic header for RPC data plane
RPC_MAGIC = 0xff271 RPC_MAGIC = 0xff271
...@@ -158,5 +158,5 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): ...@@ -158,5 +158,5 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
time.sleep(retry_period) time.sleep(retry_period)
# Still use tvm.contrib.rpc for the foreign functions # Still use tvm.rpc for the foreign functions
_init_api("tvm.contrib.rpc", "tvm.contrib.rpc.base") _init_api("tvm.rpc", "tvm.rpc.base")
...@@ -7,11 +7,11 @@ import struct ...@@ -7,11 +7,11 @@ import struct
import time import time
from . import base from . import base
from .. import util from ..contrib import util
from ..._ffi.base import TVMError from .._ffi.base import TVMError
from ..._ffi import function as function from .._ffi import function as function
from ..._ffi import ndarray as nd from .._ffi import ndarray as nd
from ...module import load as _load_module from ..module import load as _load_module
class RPCSession(object): class RPCSession(object):
...@@ -82,7 +82,7 @@ class RPCSession(object): ...@@ -82,7 +82,7 @@ class RPCSession(object):
if "upload" not in self._remote_funcs: if "upload" not in self._remote_funcs:
self._remote_funcs["upload"] = self.get_function( self._remote_funcs["upload"] = self.get_function(
"tvm.contrib.rpc.server.upload") "tvm.rpc.server.upload")
self._remote_funcs["upload"](target, blob) self._remote_funcs["upload"](target, blob)
def download(self, path): def download(self, path):
...@@ -100,7 +100,7 @@ class RPCSession(object): ...@@ -100,7 +100,7 @@ class RPCSession(object):
""" """
if "download" not in self._remote_funcs: if "download" not in self._remote_funcs:
self._remote_funcs["download"] = self.get_function( self._remote_funcs["download"] = self.get_function(
"tvm.contrib.rpc.server.download") "tvm.rpc.server.download")
return self._remote_funcs["download"](path) return self._remote_funcs["download"](path)
def load_module(self, path): def load_module(self, path):
......
...@@ -28,7 +28,7 @@ except ImportError as error_msg: ...@@ -28,7 +28,7 @@ except ImportError as error_msg:
from . import base from . import base
from .base import TrackerCode from .base import TrackerCode
from .server import _server_env from .server import _server_env
from ..._ffi.base import py_str from .._ffi.base import py_str
class ForwardHandler(object): class ForwardHandler(object):
......
...@@ -21,11 +21,11 @@ import subprocess ...@@ -21,11 +21,11 @@ import subprocess
import time import time
import sys import sys
from ..._ffi.function import register_func from .._ffi.function import register_func
from ..._ffi.base import py_str from .._ffi.base import py_str
from ..._ffi.libinfo import find_lib_path from .._ffi.libinfo import find_lib_path
from ...module import load as _load_module from ..module import load as _load_module
from .. import util from ..contrib import util
from . import base from . import base
from . base import TrackerCode from . base import TrackerCode
...@@ -36,11 +36,11 @@ def _server_env(load_library, logger): ...@@ -36,11 +36,11 @@ def _server_env(load_library, logger):
logger = logging.getLogger() logger = logging.getLogger()
# pylint: disable=unused-variable # pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.workpath") @register_func("tvm.rpc.server.workpath")
def get_workpath(path): def get_workpath(path):
return temp.relpath(path) return temp.relpath(path)
@register_func("tvm.contrib.rpc.server.load_module", override=True) @register_func("tvm.rpc.server.load_module", override=True)
def load_module(file_name): def load_module(file_name):
"""Load module from remote side.""" """Load module from remote side."""
path = temp.relpath(file_name) path = temp.relpath(file_name)
......
...@@ -39,7 +39,7 @@ except ImportError as error_msg: ...@@ -39,7 +39,7 @@ except ImportError as error_msg:
raise ImportError( raise ImportError(
"RPCTracker module requires tornado package %s" % error_msg) "RPCTracker module requires tornado package %s" % error_msg)
from ..._ffi.base import py_str from .._ffi.base import py_str
from . import base from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode from .base import RPC_TRACKER_MAGIC, TrackerCode
......
...@@ -20,7 +20,7 @@ void Module::Import(Module other) { ...@@ -20,7 +20,7 @@ void Module::Import(Module other) {
if (!std::strcmp((*this)->type_key(), "rpc")) { if (!std::strcmp((*this)->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr; static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) { if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule"); fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr); CHECK(fimport_ != nullptr);
} }
(*fimport_)(*this, other); (*fimport_)(*this, other);
......
...@@ -44,7 +44,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, ...@@ -44,7 +44,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
}); });
} }
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer") TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEventDrivenServer(args[0], args[1], args[2]); *rv = CreateEventDrivenServer(args[0], args[1], args[2]);
}); });
......
...@@ -165,7 +165,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator") ...@@ -165,7 +165,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
} }
}); });
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule") TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0]; Module m = args[0];
std::string tkey = m->type_key(); std::string tkey = m->type_key();
...@@ -177,7 +177,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule") ...@@ -177,7 +177,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
*rv = Module(n); *rv = Module(n);
}); });
TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule") TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module parent = args[0]; Module parent = args[0];
Module child = args[1]; Module child = args[1];
...@@ -192,7 +192,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule") ...@@ -192,7 +192,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
cmod->module_handle()); cmod->module_handle());
}); });
TVM_REGISTER_GLOBAL("contrib.rpc._ModuleHandle") TVM_REGISTER_GLOBAL("rpc._ModuleHandle")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0]; Module m = args[0];
std::string tkey = m->type_key(); std::string tkey = m->type_key();
...@@ -200,7 +200,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._ModuleHandle") ...@@ -200,7 +200,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._ModuleHandle")
*rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle(); *rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle();
}); });
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex") TVM_REGISTER_GLOBAL("rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0]; Module m = args[0];
std::string tkey = m->type_key(); std::string tkey = m->type_key();
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file rpc_server_env * \file rpc_server_env.cc
* \brief Server environment of the RPC. * \brief Server environment of the RPC.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -11,19 +11,19 @@ namespace runtime { ...@@ -11,19 +11,19 @@ namespace runtime {
std::string RPCGetPath(const std::string& name) { std::string RPCGetPath(const std::string& name) {
static const PackedFunc* f = static const PackedFunc* f =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath"); runtime::Registry::Get("tvm.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath"; CHECK(f != nullptr) << "require tvm.rpc.server.workpath";
return (*f)(name); return (*f)(name);
} }
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload"). TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) { set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]); std::string file_name = RPCGetPath(args[0]);
std::string data = args[1]; std::string data = args[1];
SaveBinaryToFile(file_name, data); SaveBinaryToFile(file_name, data);
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.download") TVM_REGISTER_GLOBAL("tvm.rpc.server.download")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]); std::string file_name = RPCGetPath(args[0]);
std::string data; std::string data;
......
...@@ -850,12 +850,12 @@ void RPCSession::Shutdown() { ...@@ -850,12 +850,12 @@ void RPCSession::Shutdown() {
void RPCSession::ServerLoop() { void RPCSession::ServerLoop() {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
if (const auto* f = Registry::Get("tvm.contrib.rpc.server.start")) { if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
(*f)(); (*f)();
} }
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
if (const auto* f = Registry::Get("tvm.contrib.rpc.server.shutdown")) { if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
(*f)(); (*f)();
} }
channel_.reset(nullptr); channel_.reset(nullptr);
...@@ -1046,7 +1046,7 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { ...@@ -1046,7 +1046,7 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
static const PackedFunc* fsys_load_ = nullptr; static const PackedFunc* fsys_load_ = nullptr;
if (fsys_load_ == nullptr) { if (fsys_load_ == nullptr) {
fsys_load_ = runtime::Registry::Get("tvm.contrib.rpc.server.load_module"); fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module");
CHECK(fsys_load_ != nullptr); CHECK(fsys_load_ != nullptr);
} }
std::string file_name = args[0]; std::string file_name = args[0];
......
...@@ -90,12 +90,12 @@ void RPCServerLoop(int sockfd) { ...@@ -90,12 +90,12 @@ void RPCServerLoop(int sockfd) {
"SockServerLoop", "")->ServerLoop(); "SockServerLoop", "")->ServerLoop();
} }
TVM_REGISTER_GLOBAL("contrib.rpc._Connect") TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCClientConnect(args[0], args[1], args[2]); *rv = RPCClientConnect(args[0], args[1], args[2]);
}); });
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop") TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]); RPCServerLoop(args[0]);
}); });
......
...@@ -3,7 +3,7 @@ import logging ...@@ -3,7 +3,7 @@ import logging
import numpy as np import numpy as np
import time import time
import multiprocessing import multiprocessing
from tvm.contrib import rpc from tvm import rpc
def rpc_proxy_check(): def rpc_proxy_check():
"""This is a simple test function for RPC Proxy """This is a simple test function for RPC Proxy
...@@ -17,7 +17,7 @@ def rpc_proxy_check(): ...@@ -17,7 +17,7 @@ def rpc_proxy_check():
""" """
try: try:
from tvm.contrib.rpc import proxy from tvm.rpc import proxy
web_port = 8888 web_port = 8888
prox = proxy.Proxy("localhost", web_port=web_port) prox = proxy.Proxy("localhost", web_port=web_port)
def check(): def check():
......
...@@ -3,13 +3,13 @@ import logging ...@@ -3,13 +3,13 @@ import logging
import numpy as np import numpy as np
import time import time
import multiprocessing import multiprocessing
from tvm.contrib import rpc from tvm import rpc
def check_server_drop(): def check_server_drop():
"""test when server drops""" """test when server drops"""
try: try:
from tvm.contrib.rpc import tracker, proxy, base from tvm.rpc import tracker, proxy, base
from tvm.contrib.rpc.base import TrackerCode from tvm.rpc.base import TrackerCode
@tvm.register_func("rpc.test2.addone") @tvm.register_func("rpc.test2.addone")
def addone(x): def addone(x):
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
import tvm import tvm
import os import os
import struct import struct
from tvm.contrib import util, cc, rpc from tvm import rpc
from tvm.contrib import util, cc
import numpy as np import numpy as np
def test_llvm_add_pipeline(): def test_llvm_add_pipeline():
......
import tvm import tvm
import numpy as np import numpy as np
import json import json
from tvm.contrib import rpc, util, graph_runtime from tvm import rpc
from tvm.contrib import util, graph_runtime
def test_graph_simple(): def test_graph_simple():
n = 4 n = 4
......
...@@ -3,7 +3,8 @@ import os ...@@ -3,7 +3,8 @@ import os
import logging import logging
import numpy as np import numpy as np
import time import time
from tvm.contrib import rpc, util from tvm import rpc
from tvm.contrib import util
def test_bigendian_rpc(): def test_bigendian_rpc():
......
...@@ -6,7 +6,8 @@ Connect javascript end to the websocket port and connect to the RPC. ...@@ -6,7 +6,8 @@ Connect javascript end to the websocket port and connect to the RPC.
import tvm import tvm
import os import os
from tvm.contrib import rpc, util, emscripten from tvm import rpc
from tvm.contrib import util, emscripten
import numpy as np import numpy as np
proxy_host = "localhost" proxy_host = "localhost"
......
import numpy as np import numpy as np
import tvm import tvm
from tvm.contrib import rpc, util, emscripten from tvm import rpc
from tvm.contrib import util, emscripten
def test_local_save_load(): def test_local_save_load():
if not tvm.module.enabled("opengl"): if not tvm.module.enabled("opengl"):
......
...@@ -14,7 +14,8 @@ $ python tests/webgl/test_remote_save_load.py ...@@ -14,7 +14,8 @@ $ python tests/webgl/test_remote_save_load.py
import numpy as np import numpy as np
import tvm import tvm
from tvm.contrib import rpc, util, emscripten from tvm import rpc
from tvm.contrib import util, emscripten
proxy_host = "localhost" proxy_host = "localhost"
proxy_port = 9090 proxy_port = 9090
......
"""Example code to do square matrix multiplication on Android Phone.""" """Example code to do square matrix multiplication on Android Phone."""
import tvm import tvm
import os import os
from tvm.contrib import rpc, util, ndk from tvm import rpc
from tvm.contrib import util, ndk
import numpy as np import numpy as np
# Set to be address of tvm proxy. # Set to be address of tvm proxy.
proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"] proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"]
proxy_port = 9090 proxy_port = 9090
key = "android" key = "android"
# Change target configuration. # Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch. # Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64" arch = "arm64"
...@@ -100,7 +101,7 @@ def test_gemm_gpu(N, times, bn, num_block, num_thread): ...@@ -100,7 +101,7 @@ def test_gemm_gpu(N, times, bn, num_block, num_thread):
print(tvm.lower(s, [A, B, C], simple_mode=True)) print(tvm.lower(s, [A, B, C], simple_mode=True))
f = tvm.build(s, [A, B, C], "opencl", target_host=target, name="gemm_gpu") f = tvm.build(s, [A, B, C], "opencl", target_host=target, name="gemm_gpu")
temp = util.tempdir() temp = util.tempdir()
path_dso = temp.relpath("gemm_gpu.so") path_dso = temp.relpath("gemm_gpu.so")
f.export_library(path_dso, ndk.create_shared) f.export_library(path_dso, ndk.create_shared)
......
...@@ -106,7 +106,8 @@ from __future__ import absolute_import, print_function ...@@ -106,7 +106,8 @@ from __future__ import absolute_import, print_function
import tvm import tvm
import numpy as np import numpy as np
from tvm.contrib import rpc, util from tvm import rpc
from tvm.contrib import util
###################################################################### ######################################################################
# Declare and Cross Compile Kernel on Local Machine # Declare and Cross Compile Kernel on Local Machine
......
...@@ -15,9 +15,8 @@ To begin with, we import nnvm (for compilation) and TVM (for deployment). ...@@ -15,9 +15,8 @@ To begin with, we import nnvm (for compilation) and TVM (for deployment).
import tvm import tvm
import nnvm.compiler import nnvm.compiler
import nnvm.testing import nnvm.testing
from tvm.contrib import util, rpc from tvm import rpc
from tvm.contrib import graph_runtime as runtime from tvm.contrib import util, graph_runtime as runtime
###################################################################### ######################################################################
# Build TVM Runtime on Device # Build TVM Runtime on Device
......
...@@ -11,9 +11,8 @@ To begin with, we import nnvm(for compilation) and TVM(for deployment). ...@@ -11,9 +11,8 @@ To begin with, we import nnvm(for compilation) and TVM(for deployment).
import tvm import tvm
import nnvm.compiler import nnvm.compiler
import nnvm.testing import nnvm.testing
from tvm.contrib import util, rpc from tvm import rpc
from tvm.contrib import graph_runtime as runtime from tvm.contrib import util, graph_runtime as runtime
###################################################################### ######################################################################
# Build TVM Runtime on Device # Build TVM Runtime on Device
......
...@@ -326,7 +326,8 @@ if run_deploy_local and opengl_enabled: ...@@ -326,7 +326,8 @@ if run_deploy_local and opengl_enabled:
def deploy_rpc(): def deploy_rpc():
"""Runs the demo that deploys a model remotely through RPC. """Runs the demo that deploys a model remotely through RPC.
""" """
from tvm.contrib import rpc, util, emscripten from tvm import rpc
from tvm.contrib import util, emscripten
# As usual, load the resnet18 model. # As usual, load the resnet18 model.
net, params, data_shape, out_shape = load_mxnet_resnet() net, params, data_shape, out_shape = load_mxnet_resnet()
......
...@@ -133,7 +133,7 @@ print(out.asnumpy()[0][0:10]) ...@@ -133,7 +133,7 @@ print(out.asnumpy()[0][0:10])
# `llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon` # `llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon`
# is the recommended compilation configuration, thanks to Ziheng's work. # is the recommended compilation configuration, thanks to Ziheng's work.
from tvm.contrib import rpc from tvm import rpc
use_rasp = False use_rasp = False
host = 'rasp0' host = 'rasp0'
......
...@@ -882,7 +882,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -882,7 +882,7 @@ var tvm_runtime = tvm_runtime || {};
if (typeof systemFunc.fcreateServer === "undefined") { if (typeof systemFunc.fcreateServer === "undefined") {
systemFunc.fcreateServer = systemFunc.fcreateServer =
getGlobalFunc("contrib.rpc._CreateEventDrivenServer"); getGlobalFunc("rpc._CreateEventDrivenServer");
} }
if (systemFunc.fcreateServer == null) { if (systemFunc.fcreateServer == null) {
throwError("RPCServer is not included in runtime"); throwError("RPCServer is not included in runtime");
......
...@@ -39,13 +39,13 @@ struct RPCEnv { ...@@ -39,13 +39,13 @@ struct RPCEnv {
std::string base_; std::string base_;
}; };
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath") TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env; static RPCEnv env;
*rv = env.GetPath(args[0]); *rv = env.GetPath(args[0]);
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module") TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = "/rpc/" + args[0].operator std::string(); std::string file_name = "/rpc/" + args[0].operator std::string();
*rv = Module::LoadFromFile(file_name, ""); *rv = Module::LoadFromFile(file_name, "");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment