Commit b5b51f0d by MORITA Kazutaka, committed by Tianqi Chen

[OPENCL][RUNTIME] Fix race condition of modules (#2018)

parent af974c34
@@ -2,6 +2,7 @@
 see README.md for the usage and results of this script.
 """
 import argparse
+import threading
 
 import numpy as np
 
@@ -14,6 +15,26 @@ import nnvm.testing
 
 from util import get_network
 
+
+def benchmark(network, target):
+    net, params, input_shape, output_shape = get_network(network, batch_size=1)
+
+    with nnvm.compiler.build_config(opt_level=3):
+        graph, lib, params = nnvm.compiler.build(
+            net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
+
+    # create runtime
+    ctx = tvm.context(str(target), 0)
+    module = runtime.create(graph, lib, ctx)
+    data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
+    module.set_input('data', data_tvm)
+    module.set_input(**params)
+
+    # evaluate
+    ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
+    prof_res = np.array(ftimer().results) * 1000  # multiply 1000 for converting to millisecond
+    print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--network", type=str, choices=
@@ -29,6 +50,7 @@ if __name__ == "__main__":
     parser.add_argument("--target", type=str,
                         choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal'], default='cuda',
                         help="The tvm compilation target")
+    parser.add_argument("--thread", type=int, default=1, help="The number of threads to be run.")
     args = parser.parse_args()
 
     dtype = 'float32'
@@ -44,20 +66,16 @@ if __name__ == "__main__":
     print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
     print("--------------------------------------------------")
     for network in networks:
-        net, params, input_shape, output_shape = get_network(network, batch_size=1)
-
-        with nnvm.compiler.build_config(opt_level=3):
-            graph, lib, params = nnvm.compiler.build(
-                net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
-
-        # create runtime
-        ctx = tvm.context(str(target), 0)
-        module = runtime.create(graph, lib, ctx)
-        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-        module.set_input('data', data_tvm)
-        module.set_input(**params)
-
-        # evaluate
-        ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
-        prof_res = np.array(ftimer().results) * 1000  # multiply 1000 for converting to millisecond
-        print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
+        if args.thread == 1:
+            benchmark(network, target)
+        else:
+            threads = list()
+            for n in range(args.thread):
+                thread = threading.Thread(target=benchmark, args=([network, target]), name="thread%d" % n)
+                threads.append(thread)
+
+            for thread in threads:
+                thread.start()
+
+            for thread in threads:
+                thread.join()
@@ -232,7 +232,6 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
   if (initialized_) return;
   std::lock_guard<std::mutex> lock(this->mu);
   if (initialized_) return;
-  initialized_ = true;
   if (context != nullptr) return;
   // matched platforms
   std::vector<cl_platform_id> platform_ids = cl::GetPlatformIDs();
@@ -271,6 +270,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
         clCreateCommandQueue(this->context, did, 0, &err_code));
     OPENCL_CHECK_ERROR(err_code);
   }
+  initialized_ = true;
 }
 
 TVM_REGISTER_GLOBAL("device_api.opencl")
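Why moving initialized_ matters: the first `if (initialized_) return;` runs before the lock is taken, so as soon as the flag reads true a caller proceeds straight to using the workspace. Before this patch the flag was set at the top of the locked section, leaving a window in which a second thread could observe initialized_ == true while the platform, context, and command queues were still being created, and then use a half-initialized workspace. Setting the flag only after setup completes closes that window. Below is a minimal sketch of the pattern, not the actual OpenCLWorkspace code: the class and member names are illustrative, and std::atomic is used here only to keep the unlocked fast-path read well defined, whereas the runtime checks a plain member the same way.

// Sketch of lazy initialization that publishes the "done" flag last.
#include <atomic>
#include <mutex>

class Workspace {
 public:
  void Init() {
    if (initialized_.load(std::memory_order_acquire)) return;   // unlocked fast path
    std::lock_guard<std::mutex> lock(mu_);
    if (initialized_.load(std::memory_order_relaxed)) return;    // recheck under the lock
    SetUpPlatformContextAndQueues();                             // the expensive part
    initialized_.store(true, std::memory_order_release);         // publish only when done
  }

 private:
  void SetUpPlatformContextAndQueues() { /* create context, command queues, ... */ }
  std::mutex mu_;
  std::atomic<bool> initialized_{false};
};

Setting the flag any earlier would let the fast path succeed for callers that then touch members the locked thread has not finished constructing; deferring it to the last line makes the unlocked check safe.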