Commit 51c40b4f by Tianqi Chen Committed by GitHub

[CODEGEN] Enable cross compile of AMDGPU without rocm, update rpc (#1154)

parent 11c7b6cf
......@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
def _connect(key):
conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn)
temp = _server_env()
temp = _server_env(None)
# Start connecton
conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True)
key = "server:" + key
......
......@@ -11,6 +11,7 @@ Server is TCP based with the following protocol:
from __future__ import absolute_import
import os
import ctypes
import socket
import select
import struct
......@@ -21,12 +22,13 @@ import time
from ..._ffi.function import register_func
from ..._ffi.base import py_str
from ..._ffi.libinfo import find_lib_path
from ...module import load as _load_module
from .. import util
from . import base
from . base import TrackerCode
def _server_env():
def _server_env(load_library):
"""Server environment function return temp dir"""
temp = util.tempdir()
# pylint: disable=unused-variable
......@@ -41,13 +43,21 @@ def _server_env():
m = _load_module(path)
logging.info("load_module %s", path)
return m
libs = []
load_library = load_library.split(":") if load_library else []
for file_name in load_library:
file_name = find_lib_path(file_name)[0]
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
temp.libs = libs
return temp
def _serve_loop(sock, addr):
def _serve_loop(sock, addr, load_library):
"""Server loop"""
sockfd = sock.fileno()
temp = _server_env()
temp = _server_env(load_library)
base._ServerLoop(sockfd)
temp.remove()
logging.info("Finish serving %s", addr)
......@@ -62,7 +72,7 @@ def _parse_server_opt(opts):
return ret
def _listen_loop(sock, port, rpc_key, tracker_addr):
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
"""Lisenting loop of the server master."""
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.
......@@ -162,7 +172,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# step 3: serving
logging.info("RPCServer: connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library))
server_proc.deamon = True
server_proc.start()
# close from our side.
......@@ -174,7 +184,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
server_proc.terminate()
def _connect_proxy_loop(addr, key):
def _connect_proxy_loop(addr, key, load_library):
key = "server:" + key
retry_count = 0
max_retry = 5
......@@ -198,7 +208,7 @@ def _connect_proxy_loop(addr, key):
opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr))
process = multiprocessing.Process(
target=_serve_loop, args=(sock, addr))
target=_serve_loop, args=(sock, addr, load_library))
process.deamon = True
process.start()
sock.close()
......@@ -256,6 +266,9 @@ class Server(object):
key : str, optional
The key used to identify the server in Proxy connection.
load_library : str, optional
List of additional libraries to be loaded during execution.
"""
def __init__(self,
host,
......@@ -264,7 +277,8 @@ class Server(object):
is_proxy=False,
use_popen=False,
tracker_addr=None,
key=""):
key="",
load_library=None):
try:
if base._ServerLoop is None:
raise RuntimeError("Please compile with USE_RPC=1")
......@@ -283,6 +297,8 @@ class Server(object):
assert key
cmd += ["--tracker=%s:%d" % tracker_addr,
"--key=%s" % key]
if load_library:
cmd += ["--load-libary", load_library]
self.proc = multiprocessing.Process(
target=subprocess.check_call, args=(cmd,))
self.proc.deamon = True
......@@ -308,12 +324,12 @@ class Server(object):
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr))
self.sock, self.port, key, tracker_addr, load_library))
self.proc.deamon = True
self.proc.start()
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key))
target=_connect_proxy_loop, args=((host, port), key, load_library))
self.proc.deamon = True
self.proc.start()
......
"""Start an RPC server"""
from __future__ import absolute_import
import logging
import argparse
import os
import ctypes
from ..contrib import rpc
from .._ffi.libinfo import find_lib_path
def main():
"""Main funciton"""
......@@ -19,26 +15,12 @@ def main():
help='The end search port of the PRC')
parser.add_argument('--key', type=str, default="",
help="RPC key used to identify the connection type.")
parser.add_argument('--with-executor', type=bool, default=False,
help="Whether to load executor runtime")
parser.add_argument('--load-library', type=str, default="",
help="Additional library to load")
parser.add_argument('--tracker', type=str, default="",
help="Report to RPC tracker")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
load_library = [lib for lib in args.load_library.split(":") if len(lib) != 0]
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/")
libs = []
if args.with_executor:
load_library += ["libtvm_graph_exec.so"]
for file_name in load_library:
file_name = find_lib_path(file_name, apps_path)[0]
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
if args.tracker:
url, port = args.tracker.split(":")
port = int(port)
......@@ -53,8 +35,8 @@ def main():
args.port,
args.port_end,
key=args.key,
tracker_addr=tracker_addr)
server.libs += libs
tracker_addr=tracker_addr,
load_library=args.load_library)
server.proc.join()
if __name__ == "__main__":
......
......@@ -10,6 +10,7 @@
#include <tvm/codegen.h>
#include <string>
#include <vector>
#include <functional>
#include <unordered_map>
#include "../runtime/meta_data.h"
......@@ -111,17 +112,19 @@ class CodeGenSourceBase {
runtime::Module SourceModuleCreate(std::string code, std::string fmt);
/*!
* \brief Create a source module for viewing and limited saving
* \param code The code to be viewed.
* \brief Create a source module for viewing and limited saving for device.
* \param data The code data to be viewed.
* \param fmt The code. format.
* \param fmap The map function information map of each function.
* \param type_key The type_key of the runtime module of this source code
* \param fget_source a closure to replace default get source behavior.
*/
runtime::Module DeviceSourceModuleCreate(
std::string code,
std::string data,
std::string fmt,
std::unordered_map<std::string, runtime::FunctionInfo> fmap,
std::string type_key);
std::string type_key,
std::function<std::string(const std::string&)> fget_source = nullptr);
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
......@@ -4,15 +4,18 @@
* \brief AMDGPU code generator.
*/
#ifdef TVM_LLVM_VERSION
#if TVM_ROCM_RUNTIME
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include "./codegen_llvm.h"
#include "../build_common.h"
#include "../codegen_source_base.h"
#include "../../pass/ir_util.h"
#if TVM_ROCM_RUNTIME
#include "../../runtime/rocm/rocm_module.h"
#endif // TVM_ROCM_RUNTIME
namespace tvm {
namespace codegen {
......@@ -131,19 +134,27 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
};
inline int DetectROCMComputeVersion() {
inline int DetectROCMComputeVersion(const std::string& target) {
size_t pos = target.find("=gfx");
if (pos != std::string::npos) {
int value;
std::stringstream is(target.substr(pos + 4));
if (is >> value) return value;
}
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLROCM;
tvm_ctx.device_id = 0;
TVMRetValue val;
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) {
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
return val.operator int();
} else {
return 803;
tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
if (api != nullptr) {
TVMRetValue val;
api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) {
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
return val.operator int();
}
}
LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803";
return 803;
}
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
......@@ -151,7 +162,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
target.substr(0, 4) == "rocm");
std::ostringstream config;
config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
<< DetectROCMComputeVersion()
<< DetectROCMComputeVersion(target)
<< target.substr(4, target.length() - 4);
llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
......@@ -216,7 +227,19 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
std::string hsaco = (*f)(arr);
std::string ll(data_ll.begin(), data_ll.end());
#if TVM_ROCM_RUNTIME
return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly);
#else
LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
auto fget_source = [ll, assembly](const std::string& format) {
if (format.length() == 0) return assembly;
if (format == "ll" || format == "llvm") return format;
if (format == "asm") return assembly;
return std::string("");
};
return DeviceSourceModuleCreate(
hsaco, "hsaco", ExtractFuncInfo(funcs), "hsaco", fget_source);
#endif // TVM_ROCM_RUNTIME
}
TVM_REGISTER_API("codegen.build_rocm")
......@@ -226,5 +249,4 @@ TVM_REGISTER_API("codegen.build_rocm")
} // namespace codegen
} // namespace tvm
#endif // TVM_ROCM_RUNTIME
#endif // TVM_LLVM_VERSION
......@@ -54,46 +54,71 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
}
// supports limited save without cross compile
class DeviceSourceModuleNode final : public SourceModuleNode {
class DeviceSourceModuleNode final : public runtime::ModuleNode {
public:
DeviceSourceModuleNode(std::string code,
DeviceSourceModuleNode(std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key)
: SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {}
std::string type_key,
std::function<std::string(const std::string&)> fget_source)
: data_(data),
fmt_(fmt),
fmap_(fmap),
type_key_(type_key),
fget_source_(fget_source) {}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
std::string GetSource(const std::string& format) final {
if (fget_source_ != nullptr) {
return fget_source_(format);
} else {
return data_;
}
}
const char* type_key() const {
return type_key_.c_str();
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, code_);
SaveBinaryToFile(file_name, data_);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(code_);
stream->Write(data_);
}
private:
std::string data_;
std::string fmt_;
std::unordered_map<std::string, FunctionInfo> fmap_;
std::string type_key_;
std::function<std::string(const std::string&)> fget_source_;
};
runtime::Module DeviceSourceModuleCreate(
std::string code,
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key) {
std::string type_key,
std::function<std::string(const std::string&)> fget_source) {
std::shared_ptr<DeviceSourceModuleNode> n =
std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key);
std::make_shared<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
return runtime::Module(n);
}
......
......@@ -121,9 +121,9 @@ bool RuntimeEnabled(const std::string& target) {
} else if (target == "vpi" || target == "verilog") {
f_name = "device_api.vpi";
} else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
f_name = "codegen.build_nvptx";
f_name = "device_api.gpu";
} else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
f_name = "codegen.build_rocm";
f_name = "device_api.rocm";
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
if (pf == nullptr) return false;
......
......@@ -41,13 +41,13 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
b_np = l2norm_instance_python(a_np, eps, axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_l2norm(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
......
......@@ -70,13 +70,13 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
b_np = lrn_python(a_np, size, axis, bias, alpha, beta)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_lrn(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
......
......@@ -29,7 +29,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -40,7 +41,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
s = topi.cpp.rocm.schedule_dense(target, [D])
else:
s = topi.cpp.cuda.schedule_dense(target, [D])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
......
......@@ -48,7 +48,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -57,7 +58,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_pool(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
......
......@@ -46,7 +46,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
raise NotImplementedError
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -56,7 +57,6 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
else:
s = topi.cpp.cuda.schedule_reduce(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="sum")
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
......
......@@ -7,7 +7,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.cpp.expand_dims(A, axis, num_newaxis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -16,7 +17,6 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="expand_dims")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = data_npy.reshape(out_shape)
......@@ -33,7 +33,8 @@ def verify_tranpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.cpp.transpose(A, axes)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -59,7 +60,8 @@ def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.cpp.reshape(A, dst_shape)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -68,7 +70,6 @@ def verify_reshape(src_shape, dst_shape):
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="reshape")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.reshape(data_npy, newshape=dst_shape)
......@@ -85,7 +86,8 @@ def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.cpp.squeeze(A, axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -94,7 +96,6 @@ def verify_squeeze(src_shape, axis):
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)
......@@ -116,7 +117,8 @@ def verify_concatenate(shapes, axis):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.cpp.concatenate(tensor_l, axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -125,7 +127,6 @@ def verify_concatenate(shapes, axis):
s = topi.cpp.generic.schedule_injective(target, [out_tensor])
else:
s = topi.cpp.cuda.schedule_injective(target, [out_tensor])
ctx = tvm.context(device, 0)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
out_npy = np.concatenate(data_npys, axis=axis)
......@@ -143,7 +144,8 @@ def verify_split(src_shape, indices_or_sections, axis):
tensor_l = topi.cpp.split(A, indices_or_sections, axis)
tensor_l = list(tensor_l)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment