Commit 81db22c5 by Tianqi Chen Committed by GitHub

[RPC] graduate tvm.contrib.rpc -> tvm.rpc (#1410)

parent 68e4a111
......@@ -6,7 +6,8 @@ And configure the proxy host field as commented.
import tvm
import os
from tvm.contrib import rpc, util, ndk
from tvm import rpc
from tvm.contrib import util, ndk
import numpy as np
# Set to be address of tvm proxy.
......
......@@ -6,7 +6,8 @@ And configure the proxy host field as commented.
import tvm
import os
from tvm.contrib import rpc, util, xcode
from tvm import rpc
from tvm.contrib import util, xcode
import numpy as np
# Set to be address of tvm proxy.
......
......@@ -103,13 +103,13 @@ void LaunchSyncServer() {
->ServerLoop();
}
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string name = args[0];
std::string fmt = GetFileFormat(name, "");
......
tvm.contrib.rpc
---------------
.. automodule:: tvm.contrib.rpc
tvm.rpc
-------
.. automodule:: tvm.rpc
.. autofunction:: tvm.contrib.rpc.connect
.. autofunction:: tvm.contrib.rpc.connect_tracker
.. autofunction:: tvm.rpc.connect
.. autofunction:: tvm.rpc.connect_tracker
.. autoclass:: tvm.contrib.rpc.TrackerSession
.. autoclass:: tvm.rpc.TrackerSession
:members:
:inherited-members:
.. autoclass:: tvm.contrib.rpc.RPCSession
.. autoclass:: tvm.rpc.RPCSession
:members:
:inherited-members:
.. autoclass:: tvm.contrib.rpc.LocalSession
.. autoclass:: tvm.rpc.LocalSession
:members:
:inherited-members:
.. autoclass:: tvm.contrib.rpc.Server
.. autoclass:: tvm.rpc.Server
:members:
:inherited-members:
......@@ -19,10 +19,10 @@ We can then use the following command to launch a `tvmai/demo-cpu` image.
.. code:: bash
/path/to/tvm/docker/bash.sh tvmai/demo_cpu
/path/to/tvm/docker/bash.sh tvmai/demo-cpu
.. note::
You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_
You can also change `demo-cpu` to `demo-gpu` to get a CUDA enabled image.
You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_
This auxiliary script does the following things:
......
......@@ -70,13 +70,13 @@ public class NativeServerLoop implements Runnable {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() {
Function.register("tvm.rpc.server.workpath", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return tempDir + File.separator + args[0].asString();
}
}, true);
Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() {
Function.register("tvm.rpc.server.load_module", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
String filename = args[0].asString();
String path = tempDir + File.separator + filename;
......
......@@ -37,7 +37,7 @@ public class RPC {
static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("contrib.rpc." + name);
func = Function.getFunction("rpc." + name);
if (func == null) {
return null;
}
......
......@@ -172,7 +172,7 @@ public class RPCSession {
final String funcName = "upload";
Function remoteFunc = remoteFuncs.get(funcName);
if (remoteFunc == null) {
remoteFunc = getFunction("tvm.contrib.rpc.server.upload");
remoteFunc = getFunction("tvm.rpc.server.upload");
remoteFuncs.put(funcName, remoteFunc);
}
remoteFunc.pushArg(target).pushArg(data).invoke();
......@@ -205,7 +205,7 @@ public class RPCSession {
final String name = "download";
Function func = remoteFuncs.get(name);
if (func == null) {
func = getFunction("tvm.contrib.rpc.server.download");
func = getFunction("tvm.rpc.server.download");
remoteFuncs.put(name, func);
}
return func.pushArg(path).invoke().asBytes();
......
import time
from tvm.contrib.rpc import proxy
from tvm.rpc import proxy
def start_proxy_server(port, timeout):
prox = proxy.Proxy("localhost", port=port, port_end=port+1)
......
......@@ -2,7 +2,8 @@ import os
import numpy as np
import nnvm.compiler
import tvm
from tvm.contrib import rpc, util, graph_runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime
def test_save_load():
......
import tvm
from tvm.contrib import util, rpc, graph_runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime
import nnvm.symbol as sym
import nnvm.compiler
import numpy as np
......
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .rpc import base as rpc_base
from ..rpc import base as rpc_base
from .. import ndarray as nd
......
......@@ -2,9 +2,9 @@
"""measure bandwidth and compute peak"""
import logging
import tvm
from tvm.contrib import rpc, util
from . import util
from .. import rpc
def _convert_to_remote(func, remote):
""" convert module function to remote rpc function"""
......@@ -47,7 +47,7 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride,
host compilation target
ctx: TVMcontext
the context of array
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
n_times: int
number of runs for taking mean
......@@ -107,7 +107,7 @@ def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
ctx: TVMcontext
the context of array
......@@ -165,7 +165,7 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
if it is not None, use remote rpc session
ctx: TVMcontext
the context of array
......@@ -250,7 +250,7 @@ def measure_compute_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
ctx: TVMcontext
the context of array
......
"""Deprecation RPC module"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
import warnings
from ..rpc import Server, RPCSession, LocalSession, TrackerSession, connect, connect_tracker
warnings.warn(
"Please use tvm.rpc instead of tvm.conrtib.rpc. tvm.contrib.rpc is going to be removed in 0.5",
DeprecationWarning)
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import logging
import argparse
import os
from ..contrib import rpc
from .. import rpc
def main():
"""Main funciton"""
......
......@@ -7,7 +7,7 @@ import argparse
import multiprocessing
import sys
import os
from ..contrib.rpc.proxy import Proxy
from ..rpc.proxy import Proxy
def find_example_resource():
......
......@@ -6,7 +6,7 @@ import argparse
import multiprocessing
import sys
import logging
from ..contrib import rpc
from .. import rpc
def main(args):
"""Main function"""
......
......@@ -6,7 +6,7 @@ import logging
import argparse
import multiprocessing
import sys
from ..contrib.rpc.tracker import Tracker
from ..rpc.tracker import Tracker
def main(args):
......
......@@ -9,8 +9,8 @@ import struct
import random
import logging
from ..._ffi.function import _init_api
from ..._ffi.base import py_str
from .._ffi.function import _init_api
from .._ffi.base import py_str
# Magic header for RPC data plane
RPC_MAGIC = 0xff271
......@@ -158,5 +158,5 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
time.sleep(retry_period)
# Still use tvm.contrib.rpc for the foreign functions
_init_api("tvm.contrib.rpc", "tvm.contrib.rpc.base")
# Still use tvm.rpc for the foreign functions
_init_api("tvm.rpc", "tvm.rpc.base")
......@@ -7,11 +7,11 @@ import struct
import time
from . import base
from .. import util
from ..._ffi.base import TVMError
from ..._ffi import function as function
from ..._ffi import ndarray as nd
from ...module import load as _load_module
from ..contrib import util
from .._ffi.base import TVMError
from .._ffi import function as function
from .._ffi import ndarray as nd
from ..module import load as _load_module
class RPCSession(object):
......@@ -82,7 +82,7 @@ class RPCSession(object):
if "upload" not in self._remote_funcs:
self._remote_funcs["upload"] = self.get_function(
"tvm.contrib.rpc.server.upload")
"tvm.rpc.server.upload")
self._remote_funcs["upload"](target, blob)
def download(self, path):
......@@ -100,7 +100,7 @@ class RPCSession(object):
"""
if "download" not in self._remote_funcs:
self._remote_funcs["download"] = self.get_function(
"tvm.contrib.rpc.server.download")
"tvm.rpc.server.download")
return self._remote_funcs["download"](path)
def load_module(self, path):
......
......@@ -28,7 +28,7 @@ except ImportError as error_msg:
from . import base
from .base import TrackerCode
from .server import _server_env
from ..._ffi.base import py_str
from .._ffi.base import py_str
class ForwardHandler(object):
......
......@@ -21,11 +21,11 @@ import subprocess
import time
import sys
from ..._ffi.function import register_func
from ..._ffi.base import py_str
from ..._ffi.libinfo import find_lib_path
from ...module import load as _load_module
from .. import util
from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
from . import base
from . base import TrackerCode
......@@ -36,11 +36,11 @@ def _server_env(load_library, logger):
logger = logging.getLogger()
# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.workpath")
@register_func("tvm.rpc.server.workpath")
def get_workpath(path):
return temp.relpath(path)
@register_func("tvm.contrib.rpc.server.load_module", override=True)
@register_func("tvm.rpc.server.load_module", override=True)
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
......
......@@ -39,7 +39,7 @@ except ImportError as error_msg:
raise ImportError(
"RPCTracker module requires tornado package %s" % error_msg)
from ..._ffi.base import py_str
from .._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode
......
......@@ -20,7 +20,7 @@ void Module::Import(Module other) {
if (!std::strcmp((*this)->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule");
fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr);
}
(*fimport_)(*this, other);
......
......@@ -44,7 +44,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
});
}
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEventDrivenServer(args[0], args[1], args[2]);
});
......
......@@ -165,7 +165,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
}
});
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
......@@ -177,7 +177,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
*rv = Module(n);
});
TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module parent = args[0];
Module child = args[1];
......@@ -192,7 +192,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
cmod->module_handle());
});
TVM_REGISTER_GLOBAL("contrib.rpc._ModuleHandle")
TVM_REGISTER_GLOBAL("rpc._ModuleHandle")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
......@@ -200,7 +200,7 @@ TVM_REGISTER_GLOBAL("contrib.rpc._ModuleHandle")
*rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle();
});
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
TVM_REGISTER_GLOBAL("rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_server_env
* \file rpc_server_env.cc
* \brief Server environment of the RPC.
*/
#include <tvm/runtime/registry.h>
......@@ -11,19 +11,19 @@ namespace runtime {
std::string RPCGetPath(const std::string& name) {
static const PackedFunc* f =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath";
runtime::Registry::Get("tvm.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.rpc.server.workpath";
return (*f)(name);
}
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data = args[1];
SaveBinaryToFile(file_name, data);
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.download")
TVM_REGISTER_GLOBAL("tvm.rpc.server.download")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data;
......
......@@ -850,12 +850,12 @@ void RPCSession::Shutdown() {
void RPCSession::ServerLoop() {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (const auto* f = Registry::Get("tvm.contrib.rpc.server.start")) {
if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
(*f)();
}
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
if (const auto* f = Registry::Get("tvm.contrib.rpc.server.shutdown")) {
if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
(*f)();
}
channel_.reset(nullptr);
......@@ -1046,7 +1046,7 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
static const PackedFunc* fsys_load_ = nullptr;
if (fsys_load_ == nullptr) {
fsys_load_ = runtime::Registry::Get("tvm.contrib.rpc.server.load_module");
fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module");
CHECK(fsys_load_ != nullptr);
}
std::string file_name = args[0];
......
......@@ -90,12 +90,12 @@ void RPCServerLoop(int sockfd) {
"SockServerLoop", "")->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCClientConnect(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
......
......@@ -3,7 +3,7 @@ import logging
import numpy as np
import time
import multiprocessing
from tvm.contrib import rpc
from tvm import rpc
def rpc_proxy_check():
"""This is a simple test function for RPC Proxy
......@@ -17,7 +17,7 @@ def rpc_proxy_check():
"""
try:
from tvm.contrib.rpc import proxy
from tvm.rpc import proxy
web_port = 8888
prox = proxy.Proxy("localhost", web_port=web_port)
def check():
......
......@@ -3,13 +3,13 @@ import logging
import numpy as np
import time
import multiprocessing
from tvm.contrib import rpc
from tvm import rpc
def check_server_drop():
"""test when server drops"""
try:
from tvm.contrib.rpc import tracker, proxy, base
from tvm.contrib.rpc.base import TrackerCode
from tvm.rpc import tracker, proxy, base
from tvm.rpc.base import TrackerCode
@tvm.register_func("rpc.test2.addone")
def addone(x):
......
......@@ -2,7 +2,8 @@
import tvm
import os
import struct
from tvm.contrib import util, cc, rpc
from tvm import rpc
from tvm.contrib import util, cc
import numpy as np
def test_llvm_add_pipeline():
......
import tvm
import numpy as np
import json
from tvm.contrib import rpc, util, graph_runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime
def test_graph_simple():
n = 4
......
......@@ -3,7 +3,8 @@ import os
import logging
import numpy as np
import time
from tvm.contrib import rpc, util
from tvm import rpc
from tvm.contrib import util
def test_bigendian_rpc():
......
......@@ -6,7 +6,8 @@ Connect javascript end to the websocket port and connect to the RPC.
import tvm
import os
from tvm.contrib import rpc, util, emscripten
from tvm import rpc
from tvm.contrib import util, emscripten
import numpy as np
proxy_host = "localhost"
......
import numpy as np
import tvm
from tvm.contrib import rpc, util, emscripten
from tvm import rpc
from tvm.contrib import util, emscripten
def test_local_save_load():
if not tvm.module.enabled("opengl"):
......
......@@ -14,7 +14,8 @@ $ python tests/webgl/test_remote_save_load.py
import numpy as np
import tvm
from tvm.contrib import rpc, util, emscripten
from tvm import rpc
from tvm.contrib import util, emscripten
proxy_host = "localhost"
proxy_port = 9090
......
"""Example code to do square matrix multiplication on Android Phone."""
import tvm
import os
from tvm.contrib import rpc, util, ndk
from tvm import rpc
from tvm.contrib import util, ndk
import numpy as np
# Set to be address of tvm proxy.
proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"]
proxy_port = 9090
key = "android"
# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
......@@ -100,7 +101,7 @@ def test_gemm_gpu(N, times, bn, num_block, num_thread):
print(tvm.lower(s, [A, B, C], simple_mode=True))
f = tvm.build(s, [A, B, C], "opencl", target_host=target, name="gemm_gpu")
temp = util.tempdir()
temp = util.tempdir()
path_dso = temp.relpath("gemm_gpu.so")
f.export_library(path_dso, ndk.create_shared)
......
......@@ -106,7 +106,8 @@ from __future__ import absolute_import, print_function
import tvm
import numpy as np
from tvm.contrib import rpc, util
from tvm import rpc
from tvm.contrib import util
######################################################################
# Declare and Cross Compile Kernel on Local Machine
......
......@@ -15,9 +15,8 @@ To begin with, we import nnvm (for compilation) and TVM (for deployment).
import tvm
import nnvm.compiler
import nnvm.testing
from tvm.contrib import util, rpc
from tvm.contrib import graph_runtime as runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime as runtime
######################################################################
# Build TVM Runtime on Device
......
......@@ -11,9 +11,8 @@ To begin with, we import nnvm(for compilation) and TVM(for deployment).
import tvm
import nnvm.compiler
import nnvm.testing
from tvm.contrib import util, rpc
from tvm.contrib import graph_runtime as runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime as runtime
######################################################################
# Build TVM Runtime on Device
......
......@@ -326,7 +326,8 @@ if run_deploy_local and opengl_enabled:
def deploy_rpc():
"""Runs the demo that deploys a model remotely through RPC.
"""
from tvm.contrib import rpc, util, emscripten
from tvm import rpc
from tvm.contrib import util, emscripten
# As usual, load the resnet18 model.
net, params, data_shape, out_shape = load_mxnet_resnet()
......
......@@ -133,7 +133,7 @@ print(out.asnumpy()[0][0:10])
# `llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon`
# is the recommended compilation configuration, thanks to Ziheng's work.
from tvm.contrib import rpc
from tvm import rpc
use_rasp = False
host = 'rasp0'
......
......@@ -882,7 +882,7 @@ var tvm_runtime = tvm_runtime || {};
if (typeof systemFunc.fcreateServer === "undefined") {
systemFunc.fcreateServer =
getGlobalFunc("contrib.rpc._CreateEventDrivenServer");
getGlobalFunc("rpc._CreateEventDrivenServer");
}
if (systemFunc.fcreateServer == null) {
throwError("RPCServer is not included in runtime");
......
......@@ -39,13 +39,13 @@ struct RPCEnv {
std::string base_;
};
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = "/rpc/" + args[0].operator std::string();
*rv = Module::LoadFromFile(file_name, "");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment