Commit b5b51f0d by MORITA Kazutaka Committed by Tianqi Chen

[OPENCL][RUNTIME] Fix race condition of modules (#2018)

parent af974c34
......@@ -2,6 +2,7 @@
see README.md for the usage and results of this script.
"""
import argparse
import threading
import numpy as np
......@@ -14,6 +15,26 @@ import nnvm.testing
from util import get_network
def benchmark(network, target):
net, params, input_shape, output_shape = get_network(network, batch_size=1)
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
# create runtime
ctx = tvm.context(str(target), 0)
module = runtime.create(graph, lib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params)
# evaluate
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices=
......@@ -29,6 +50,7 @@ if __name__ == "__main__":
parser.add_argument("--target", type=str,
choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal'], default='cuda',
help="The tvm compilation target")
parser.add_argument("--thread", type=int, default=1, help="The number of threads to be run.")
args = parser.parse_args()
dtype = 'float32'
......@@ -44,20 +66,16 @@ if __name__ == "__main__":
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
print("--------------------------------------------------")
for network in networks:
net, params, input_shape, output_shape = get_network(network, batch_size=1)
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
# create runtime
ctx = tvm.context(str(target), 0)
module = runtime.create(graph, lib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params)
# evaluate
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
if args.thread == 1:
benchmark(network, target)
else:
threads = list()
for n in range(args.thread):
thread = threading.Thread(target=benchmark, args=([network, target]), name="thread%d" % n)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
......@@ -232,7 +232,6 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return;
initialized_ = true;
if (context != nullptr) return;
// matched platforms
std::vector<cl_platform_id> platform_ids = cl::GetPlatformIDs();
......@@ -271,6 +270,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
clCreateCommandQueue(this->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
}
initialized_ = true;
}
TVM_REGISTER_GLOBAL("device_api.opencl")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment