Commit ec3a4251 by Marina Kolpakova, committed by Tianqi Chen

A couple of fixes for GEN (#2593)

parent 77718f8e
@@ -302,7 +302,7 @@ def lower(sch,
     Parameters
     ----------
     sch : tvm.schedule.Schedule
-        The schedule to be builded
+        The schedule to be built
     args : list of Buffer or Tensor or Var
         The argument lists to the function.
......
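For reference, lower() is typically invoked as in the minimal sketch below, assuming a trivial compute and a default schedule (the names n, A, B and s are illustrative, not part of this commit):

import tvm

# Declare a toy elementwise computation plus the default schedule.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)

# lower() takes the schedule to be built plus its argument list;
# simple_mode=True returns a readable lowered IR body for inspection.
print(tvm.lower(s, [A, B], simple_mode=True))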
@@ -159,11 +159,12 @@ class GraphModuleDebug(graph_runtime.GraphModule):
         self.debug_datum = debug_result.DebugResult(graph_json, self._dump_path)

     def _run_debug(self):
-        """Execute the node spcified with index will be executed.
+        """Execute the node specified with index will be executed.
         Each debug output will be copied to the buffer
-        Time consumed for each execuion will be set as debug output.
+        Time consumed for each execution will be set as debug output.
         """
+        self.debug_datum._time_list = []
         for i, node in enumerate(self.debug_datum.get_graph_nodes()):
             start_time = datetime.now().time()
@@ -177,7 +178,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
             self.debug_datum._output_tensor_list.append(out_tensor)

     def debug_get_output(self, node, out):
-        """Run graph upto node and get the output to out
+        """Run graph up to node and get the output to out
         Parameters
         ----------
......
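As context for the timing list initialized above, a hedged sketch of how the debug runtime is usually driven (graph, lib and params are assumed to come from an earlier graph build; the dump_root path is illustrative):

import tvm
from tvm.contrib.debugger import debug_runtime

# create() wraps the compiled graph in GraphModuleDebug; run() invokes
# _run_debug(), which fills _time_list with per-node execution times
# and writes per-node traces under dump_root.
m = debug_runtime.create(graph, lib, tvm.cpu(0), dump_root="/tmp/tvmdbg")
m.set_input(**params)
m.run()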
@@ -130,7 +130,7 @@ struct ThreadScope {
 };

-/*! \brief workload speccification */
+/*! \brief workload specification */
 struct ThreadWorkLoad {
   // array, first three are thread configuration.
   size_t work_size[6];
......
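For orientation, the six work_size slots hold the grid and block extents of a kernel launch. A hedged Python-side sketch of what ultimately populates them (s, C, bx, tx assume a prior split, as in the tutorial hunks below):

# Binding loop axes to GPU thread axes determines the extents recorded
# in ThreadWorkLoad::work_size when the kernel is launched.
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))   # grid extent
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))  # block extent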
@@ -31,7 +31,7 @@ import numpy as np
 ######################################################################
 # We first write a very simple vector add and build it with the default schedule. Then, we use
-# our customized lowering pass to manipulate the IR directly instead of using schedule premitives.
+# our customized lowering pass to manipulate the IR directly instead of using schedule primitives.
 #

 n = tvm.const(128, "int32")
......
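For orientation, the simple vector add this tutorial starts from looks roughly like the sketch below (reconstructed from the surrounding tutorial text, not part of the diff):

import tvm
import numpy as np

n = tvm.const(128, "int32")
A = tvm.placeholder((n,), name="A")
B = tvm.placeholder((n,), name="B")
C = tvm.compute((n,), lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)  # default schedule, no primitives applied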
@@ -94,7 +94,7 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
 # compute grid. These are GPU specific constructs that allows us
 # to generate code that runs on GPU.
 #
-if tgt == "cuda":
+if tgt == "cuda" or tgt.startswith('opencl'):
     s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
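After binding, the tutorial builds the function for the chosen target. A sketch of that step (the tgt value is illustrative; with this patch, any target string starting with 'opencl' now takes the GPU path as well):

tgt = "opencl"  # or "cuda", or e.g. "opencl -device=intel_graphics"
fadd = tvm.build(s, [A, B, C], tgt, target_host="llvm", name="myadd")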
@@ -149,7 +149,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 #
 # The following code fetches the device module and prints the content code.
 #
-if tgt == "cuda":
+if tgt == "cuda" or tgt.startswith('opencl'):
     dev_module = fadd.imported_modules[0]
     print("-----GPU code-----")
     print(dev_module.get_source())
@@ -193,6 +193,8 @@ temp = util.tempdir()
 fadd.save(temp.relpath("myadd.o"))
 if tgt == "cuda":
     fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
+if tgt.startswith('opencl'):
+    fadd.imported_modules[0].save(temp.relpath("myadd.cl"))
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
 print(temp.listdir())
@@ -200,29 +202,34 @@ print(temp.listdir())
 # .. note:: Module Storage Format
 #
 #   The CPU(host) module is directly saved as a shared library(so).
-#   There can be multiple customed format on the device code.
+#   There can be multiple customized format on the device code.
 #   In our example, device code is stored in ptx, as well as a meta
-#   data json file. They can be loaded and linked seperatedly via import.
+#   data json file. They can be loaded and linked separately via import.
 #

 ######################################################################
 # Load Compiled Module
 # --------------------
 # We can load the compiled module from the file system and run the code.
-# The following code load the host and device module seperatedly and
+# The following code load the host and device module separately and
 # re-link them together. We can verify that the newly loaded function works.
 #
 fadd1 = tvm.module.load(temp.relpath("myadd.so"))
 if tgt == "cuda":
     fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
     fadd1.import_module(fadd1_dev)
+
+if tgt.startswith('opencl'):
+    fadd1_dev = tvm.module.load(temp.relpath("myadd.cl"))
+    fadd1.import_module(fadd1_dev)
+
 fadd1(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

 ######################################################################
 # Pack Everything into One Library
 # --------------------------------
-# In the above example, we store the device and host code seperatedly.
+# In the above example, we store the device and host code separately.
 # TVM also supports export everything as one shared library.
 # Under the hood, we pack the device modules into binary blobs and link
 # them together with the host code.
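A hedged sketch of the packed-library path this note describes, using the standard export_library API (the file name is illustrative):

# Pack the host code and all imported device modules into one .so,
# then load it back and verify the re-loaded function.
fadd.export_library(temp.relpath("myadd_pack.so"))
fadd2 = tvm.module.load(temp.relpath("myadd_pack.so"))
fadd2(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())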
@@ -254,8 +261,8 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # The following codeblocks generate opencl code, creates array on opencl
 # device, and verifies the correctness of the code.
 #
-if tgt == "opencl":
-    fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
+if tgt.startswith('opencl'):
+    fadd_cl = tvm.build(s, [A, B, C], tgt, name="myadd")
     print("------opencl code------")
     print(fadd_cl.imported_modules[0].get_source())
     ctx = tvm.cl(0)
......
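The elided remainder of that block presumably mirrors the earlier verification step; a sketch under that assumption (the array size is illustrative):

# Allocate inputs and outputs on the OpenCL device, then check results.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd_cl(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())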