Commit 51c40b4f by Tianqi Chen, committed by GitHub

[CODEGEN] Enable cross compile of AMDGPU without rocm, update rpc (#1154)

parent 11c7b6cf
@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
     def _connect(key):
         conn = yield websocket.websocket_connect(url)
         on_message = create_on_message(conn)
-        temp = _server_env()
+        temp = _server_env(None)
         # Start connection
         conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True)
         key = "server:" + key
......
@@ -11,6 +11,7 @@ Server is TCP based with the following protocol:
 from __future__ import absolute_import

 import os
+import ctypes
 import socket
 import select
 import struct
@@ -21,12 +22,13 @@ import time
 from ..._ffi.function import register_func
 from ..._ffi.base import py_str
+from ..._ffi.libinfo import find_lib_path
 from ...module import load as _load_module
 from .. import util
 from . import base
 from . base import TrackerCode


-def _server_env():
+def _server_env(load_library):
     """Server environment function return temp dir"""
     temp = util.tempdir()
     # pylint: disable=unused-variable
@@ -41,13 +43,21 @@ def _server_env():
         m = _load_module(path)
         logging.info("load_module %s", path)
         return m
+
+    libs = []
+    load_library = load_library.split(":") if load_library else []
+    for file_name in load_library:
+        file_name = find_lib_path(file_name)[0]
+        libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
+        logging.info("Load additional library %s", file_name)
+    temp.libs = libs
     return temp


-def _serve_loop(sock, addr):
+def _serve_loop(sock, addr, load_library):
     """Server loop"""
     sockfd = sock.fileno()
-    temp = _server_env()
+    temp = _server_env(load_library)
     base._ServerLoop(sockfd)
     temp.remove()
     logging.info("Finish serving %s", addr)
@@ -62,7 +72,7 @@ def _parse_server_opt(opts):
     return ret


-def _listen_loop(sock, port, rpc_key, tracker_addr):
+def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
     """Listening loop of the server master."""
     def _accept_conn(listen_sock, tracker_conn, ping_period=2):
         """Accept connection from the other places.
@@ -162,7 +172,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
         # step 3: serving
         logging.info("RPCServer: connection from %s", addr)
-        server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
+        server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library))
         server_proc.deamon = True
         server_proc.start()
         # close from our side.
@@ -174,7 +184,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
         server_proc.terminate()


-def _connect_proxy_loop(addr, key):
+def _connect_proxy_loop(addr, key, load_library):
     key = "server:" + key
     retry_count = 0
     max_retry = 5
@@ -198,7 +208,7 @@ def _connect_proxy_loop(addr, key):
             opts = _parse_server_opt(remote_key.split()[1:])
             logging.info("RPCProxy connected to %s", str(addr))
             process = multiprocessing.Process(
-                target=_serve_loop, args=(sock, addr))
+                target=_serve_loop, args=(sock, addr, load_library))
             process.deamon = True
             process.start()
             sock.close()
@@ -256,6 +266,9 @@ class Server(object):
     key : str, optional
         The key used to identify the server in Proxy connection.

+    load_library : str, optional
+        List of additional libraries to be loaded during execution.
+
     """
     def __init__(self,
                  host,
@@ -264,7 +277,8 @@ class Server(object):
                  is_proxy=False,
                  use_popen=False,
                  tracker_addr=None,
-                 key=""):
+                 key="",
+                 load_library=None):
         try:
             if base._ServerLoop is None:
                 raise RuntimeError("Please compile with USE_RPC=1")
@@ -283,6 +297,8 @@ class Server(object):
                 assert key
                 cmd += ["--tracker=%s:%d" % tracker_addr,
                         "--key=%s" % key]
+            if load_library:
+                cmd += ["--load-libary", load_library]
             self.proc = multiprocessing.Process(
                 target=subprocess.check_call, args=(cmd,))
             self.proc.deamon = True
@@ -308,12 +324,12 @@ class Server(object):
             self.sock = sock
             self.proc = multiprocessing.Process(
                 target=_listen_loop, args=(
-                    self.sock, self.port, key, tracker_addr))
+                    self.sock, self.port, key, tracker_addr, load_library))
             self.proc.deamon = True
             self.proc.start()
         else:
             self.proc = multiprocessing.Process(
-                target=_connect_proxy_loop, args=((host, port), key))
+                target=_connect_proxy_loop, args=((host, port), key, load_library))
             self.proc.deamon = True
             self.proc.start()
......
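A hedged usage sketch of the new parameter on the Python side; the host, port, key, and library name below are placeholder values, not part of this change:

    from tvm.contrib import rpc

    # start a standalone RPC server that preloads an extra shared library
    server = rpc.Server("0.0.0.0", port=9090, key="rocm",
                        load_library="libmyops.so")
    server.proc.join()
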
"""Start an RPC server""" """Start an RPC server"""
from __future__ import absolute_import from __future__ import absolute_import
import logging
import argparse import argparse
import os
import ctypes
from ..contrib import rpc from ..contrib import rpc
from .._ffi.libinfo import find_lib_path
def main(): def main():
"""Main funciton""" """Main funciton"""
...@@ -19,26 +15,12 @@ def main(): ...@@ -19,26 +15,12 @@ def main():
help='The end search port of the PRC') help='The end search port of the PRC')
parser.add_argument('--key', type=str, default="", parser.add_argument('--key', type=str, default="",
help="RPC key used to identify the connection type.") help="RPC key used to identify the connection type.")
parser.add_argument('--with-executor', type=bool, default=False,
help="Whether to load executor runtime")
parser.add_argument('--load-library', type=str, default="", parser.add_argument('--load-library', type=str, default="",
help="Additional library to load") help="Additional library to load")
parser.add_argument('--tracker', type=str, default="", parser.add_argument('--tracker', type=str, default="",
help="Report to RPC tracker") help="Report to RPC tracker")
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
load_library = [lib for lib in args.load_library.split(":") if len(lib) != 0]
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/")
libs = []
if args.with_executor:
load_library += ["libtvm_graph_exec.so"]
for file_name in load_library:
file_name = find_lib_path(file_name, apps_path)[0]
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
if args.tracker: if args.tracker:
url, port = args.tracker.split(":") url, port = args.tracker.split(":")
port = int(port) port = int(port)
...@@ -53,8 +35,8 @@ def main(): ...@@ -53,8 +35,8 @@ def main():
args.port, args.port,
args.port_end, args.port_end,
key=args.key, key=args.key,
tracker_addr=tracker_addr) tracker_addr=tracker_addr,
server.libs += libs load_library=args.load_library)
server.proc.join() server.proc.join()
if __name__ == "__main__": if __name__ == "__main__":
......
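The same option is exposed on the standalone server above, e.g. python -m tvm.exec.rpc_server --key rocm --load-library libmyops.so (library name again a placeholder). A client can then connect and load modules that rely on the preloaded symbols through the existing rpc client calls; a rough sketch, with placeholder host, port, and file names:

    from tvm.contrib import rpc

    remote = rpc.connect("my-rocm-box", 9090)  # placeholder host and port
    remote.upload("myadd.so")                  # module cross compiled for the remote
    fadd = remote.load_module("myadd.so")
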
@@ -10,6 +10,7 @@
 #include <tvm/codegen.h>
 #include <string>
 #include <vector>
+#include <functional>
 #include <unordered_map>
 #include "../runtime/meta_data.h"
@@ -111,17 +112,19 @@ class CodeGenSourceBase {
 runtime::Module SourceModuleCreate(std::string code, std::string fmt);

 /*!
- * \brief Create a source module for viewing and limited saving
- * \param code The code to be viewed.
+ * \brief Create a source module for viewing and limited saving for device.
+ * \param data The code data to be viewed.
  * \param fmt The code format.
  * \param fmap The map function information map of each function.
  * \param type_key The type_key of the runtime module of this source code
+ * \param fget_source a closure to replace default get source behavior.
  */
 runtime::Module DeviceSourceModuleCreate(
-    std::string code,
+    std::string data,
     std::string fmt,
     std::unordered_map<std::string, runtime::FunctionInfo> fmap,
-    std::string type_key);
+    std::string type_key,
+    std::function<std::string(const std::string&)> fget_source = nullptr);

 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
@@ -4,15 +4,18 @@
  * \brief AMDGPU code generator.
  */
 #ifdef TVM_LLVM_VERSION
-#if TVM_ROCM_RUNTIME

 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/registry.h>
 #include "./codegen_llvm.h"
 #include "../build_common.h"
+#include "../codegen_source_base.h"
 #include "../../pass/ir_util.h"
+
+#if TVM_ROCM_RUNTIME
 #include "../../runtime/rocm/rocm_module.h"
+#endif  // TVM_ROCM_RUNTIME

 namespace tvm {
 namespace codegen {
@@ -131,19 +134,27 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   }
 };

-inline int DetectROCMComputeVersion() {
+inline int DetectROCMComputeVersion(const std::string& target) {
+  size_t pos = target.find("=gfx");
+  if (pos != std::string::npos) {
+    int value;
+    std::stringstream is(target.substr(pos + 4));
+    if (is >> value) return value;
+  }
   TVMContext tvm_ctx;
   tvm_ctx.device_type = kDLROCM;
   tvm_ctx.device_id = 0;
-  TVMRetValue val;
-  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
-      tvm_ctx, tvm::runtime::kExist, &val);
-  if (val.operator int() == 1) {
-    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
-    return val.operator int();
-  } else {
-    return 803;
+  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
+  if (api != nullptr) {
+    TVMRetValue val;
+    api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
+    if (val.operator int() == 1) {
+      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
+      return val.operator int();
+    }
   }
+  LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803";
+  return 803;
 }

 runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
@@ -151,7 +162,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
              target.substr(0, 4) == "rocm");
   std::ostringstream config;
   config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
-         << DetectROCMComputeVersion()
+         << DetectROCMComputeVersion(target)
          << target.substr(4, target.length() - 4);
   llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
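With the target-string parsing above, the gfx version can be named directly in the rocm target, so cross compilation no longer needs a visible ROCm device. A rough Python sketch of such a build; the schedule and the gfx900 value are illustrative, not part of this change:

    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    # "-mcpu=gfx900" feeds DetectROCMComputeVersion, so no rocm device is queried
    mod = tvm.build(s, [A, B], target="rocm -mcpu=gfx900")
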
@@ -216,7 +227,19 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
   std::string hsaco = (*f)(arr);
   std::string ll(data_ll.begin(), data_ll.end());
+#if TVM_ROCM_RUNTIME
   return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly);
+#else
+  LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
+  auto fget_source = [ll, assembly](const std::string& format) {
+    if (format.length() == 0) return assembly;
+    if (format == "ll" || format == "llvm") return format;
+    if (format == "asm") return assembly;
+    return std::string("");
+  };
+  return DeviceSourceModuleCreate(
+      hsaco, "hsaco", ExtractFuncInfo(funcs), "hsaco", fget_source);
+#endif  // TVM_ROCM_RUNTIME
 }

 TVM_REGISTER_API("codegen.build_rocm")
@@ -226,5 +249,4 @@ TVM_REGISTER_API("codegen.build_rocm")

 }  // namespace codegen
 }  // namespace tvm
-#endif  // TVM_ROCM_RUNTIME
 #endif  // TVM_LLVM_VERSION
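When TVM is built without the ROCm runtime, the device part of the module built in the earlier sketch is held in the DeviceSourceModuleNode introduced below, and the IR/assembly are reachable through the fget_source closure. A hedged sketch of inspecting it; the attribute chain is the usual Python module API, and the file name is a placeholder:

    dev_mod = mod.imported_modules[0]  # device part of the build from the sketch above
    print(dev_mod.type_key)            # expected to be "hsaco"
    print(dev_mod.get_source("asm"))   # GCN assembly via the fget_source closure
    dev_mod.save("myadd.hsaco")        # limited save in the module's own format
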
@@ -54,46 +54,71 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
 }

 // supports limited save without cross compile
-class DeviceSourceModuleNode final : public SourceModuleNode {
+class DeviceSourceModuleNode final : public runtime::ModuleNode {
  public:
-  DeviceSourceModuleNode(std::string code,
+  DeviceSourceModuleNode(std::string data,
                          std::string fmt,
                          std::unordered_map<std::string, FunctionInfo> fmap,
-                         std::string type_key)
-      : SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {}
+                         std::string type_key,
+                         std::function<std::string(const std::string&)> fget_source)
+      : data_(data),
+        fmt_(fmt),
+        fmap_(fmap),
+        type_key_(type_key),
+        fget_source_(fget_source) {}
+
+  PackedFunc GetFunction(
+      const std::string& name,
+      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+    LOG(FATAL) << "Source module cannot execute, to get executable module"
+               << " build TVM with \'" << fmt_ << "\' runtime support";
+    return PackedFunc();
+  }
+
+  std::string GetSource(const std::string& format) final {
+    if (fget_source_ != nullptr) {
+      return fget_source_(format);
+    } else {
+      return data_;
+    }
+  }

   const char* type_key() const {
     return type_key_.c_str();
   }

   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
     std::string fmt = GetFileFormat(file_name, format);
     CHECK_EQ(fmt, fmt_)
         << "Can only save to format=" << fmt_;
     std::string meta_file = GetMetaFilePath(file_name);
     SaveMetaDataToFile(meta_file, fmap_);
-    SaveBinaryToFile(file_name, code_);
+    SaveBinaryToFile(file_name, data_);
   }

   void SaveToBinary(dmlc::Stream* stream) final {
     stream->Write(fmt_);
     stream->Write(fmap_);
-    stream->Write(code_);
+    stream->Write(data_);
   }

  private:
+  std::string data_;
+  std::string fmt_;
   std::unordered_map<std::string, FunctionInfo> fmap_;
   std::string type_key_;
+  std::function<std::string(const std::string&)> fget_source_;
 };

 runtime::Module DeviceSourceModuleCreate(
-    std::string code,
+    std::string data,
     std::string fmt,
     std::unordered_map<std::string, FunctionInfo> fmap,
-    std::string type_key) {
+    std::string type_key,
+    std::function<std::string(const std::string&)> fget_source) {
   std::shared_ptr<DeviceSourceModuleNode> n =
-      std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key);
+      std::make_shared<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
   return runtime::Module(n);
 }
......
@@ -121,9 +121,9 @@ bool RuntimeEnabled(const std::string& target) {
   } else if (target == "vpi" || target == "verilog") {
     f_name = "device_api.vpi";
   } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
-    f_name = "codegen.build_nvptx";
+    f_name = "device_api.gpu";
   } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
-    f_name = "codegen.build_rocm";
+    f_name = "device_api.rocm";
   } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
     const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
     if (pf == nullptr) return false;
......
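RuntimeEnabled now keys nvptx and rocm off the presence of the device API rather than the code generator, which is also why the tests below stop calling tvm.module.enabled(device) and query the device directly. The pattern they switch to is simply:

    import tvm

    ctx = tvm.context("rocm", 0)
    if not ctx.exist:
        print("Skip because rocm is not enabled")
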
@@ -41,13 +41,13 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
     b_np = l2norm_instance_python(a_np, eps, axis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_l2norm(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
......
@@ -70,13 +70,13 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
     b_np = lrn_python(a_np, size, axis, bias, alpha, beta)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_lrn(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
......
@@ -29,7 +29,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
     a_np, b_np, c_np, d_np = get_ref_data()

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -40,7 +41,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
             s = topi.cpp.rocm.schedule_dense(target, [D])
         else:
             s = topi.cpp.cuda.schedule_dense(target, [D])
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
......
@@ -48,7 +48,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
         b_np = np.maximum(b_np, 0.0)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -57,7 +58,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
             s = topi.cpp.generic.default_schedule(target, [B], False)
         else:
             s = topi.cpp.cuda.schedule_pool(target, [B])
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
......
@@ -46,7 +46,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         raise NotImplementedError

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -56,7 +57,6 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         else:
             s = topi.cpp.cuda.schedule_reduce(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="sum")
         # Test
         in_npy = np.random.uniform(size=in_shape).astype(np.float32)
......
@@ -7,7 +7,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.cpp.expand_dims(A, axis, num_newaxis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -16,7 +17,6 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
             s = topi.cpp.generic.schedule_injective(target, [B])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="expand_dims")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = data_npy.reshape(out_shape)
@@ -33,7 +33,8 @@ def verify_tranpose(in_shape, axes):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.cpp.transpose(A, axes)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -59,7 +60,8 @@ def verify_reshape(src_shape, dst_shape):
     A = tvm.placeholder(shape=src_shape, name="A")
     B = topi.cpp.reshape(A, dst_shape)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -68,7 +70,6 @@ def verify_reshape(src_shape, dst_shape):
             s = topi.cpp.generic.schedule_injective(target, [B])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="reshape")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.reshape(data_npy, newshape=dst_shape)
@@ -85,7 +86,8 @@ def verify_squeeze(src_shape, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
     B = topi.cpp.squeeze(A, axis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -94,7 +96,6 @@ def verify_squeeze(src_shape, axis):
             s = topi.cpp.generic.schedule_injective(target, [B])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="squeeze")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.squeeze(data_npy, axis=axis)
@@ -116,7 +117,8 @@ def verify_concatenate(shapes, axis):
         tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
     out_tensor = topi.cpp.concatenate(tensor_l, axis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -125,7 +127,6 @@ def verify_concatenate(shapes, axis):
             s = topi.cpp.generic.schedule_injective(target, [out_tensor])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [out_tensor])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
         data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
         out_npy = np.concatenate(data_npys, axis=axis)
@@ -143,7 +144,8 @@ def verify_split(src_shape, indices_or_sections, axis):
     tensor_l = topi.cpp.split(A, indices_or_sections, axis)
     tensor_l = list(tensor_l)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
......