Commit e3ddc8da by Siva, committed by Tianqi Chen

[DOC] Generalize the get_started script for beginners with different environments. (#798)

parent d1cdb623
@@ -13,6 +13,12 @@ from __future__ import absolute_import, print_function
 import tvm
 import numpy as np
+# Global declarations of environment.
+tgt_host="llvm"
+# Change it to the respective GPU target if GPU is enabled, e.g. cuda, opencl
+tgt="cuda"
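Since the point of this change is to let beginners run the same script in different environments, a possible refinement (not part of this commit) is to read the target from an environment variable instead of editing the file. A minimal sketch; the variable name TVM_TUTORIAL_TGT is hypothetical:

import os

# Sketch only, not in the commit: pick the device target from the environment
# so the script runs unmodified on CPU-only and GPU machines,
# e.g.  TVM_TUTORIAL_TGT=opencl python get_started.py
tgt = os.environ.get("TVM_TUTORIAL_TGT", "llvm")
tgt_host = "llvm"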
 ######################################################################
 # Vector Add Example
 # ------------------
@@ -88,8 +94,9 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
 # compute grid. These are GPU specific constructs that allow us
 # to generate code that runs on GPU.
 #
-s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
-s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+if tgt == "cuda":
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
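The binding above only applies to the GPU path; with tgt = "llvm" the loops are simply left as generated. A minimal sketch (not part of the commit) of how the same split could instead be used for CPU threading, assuming the s[C].parallel schedule primitive in this TVM version:

# Sketch only: run the outer loop produced by the split across CPU threads
# when targeting the LLVM (CPU) backend.
if tgt == "llvm":
    s[C].parallel(bx)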
 ######################################################################
 # Compilation
@@ -103,12 +110,12 @@ s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
 # function (including the inputs and outputs) as well as the target language
 # we want to compile to.
 #
-# The result of compilation fadd is a CUDA device function that can
-# as well as a host wrapper that calls into the CUDA function.
+# The result of compilation, fadd, is a GPU device function (if a GPU is involved)
+# as well as a host wrapper that calls into the GPU function.
 # fadd is the generated host wrapper function; it contains a reference
 # to the generated device function internally.
 #
-fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
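With the target now configurable, a natural companion check (not part of this commit) is to verify that the selected backend was actually compiled into TVM before calling tvm.build. A minimal sketch, placed above the build call and assuming tvm.module.enabled exists in this TVM version:

# Sketch only: fall back to the CPU target when the requested runtime
# (e.g. cuda or opencl) is not enabled in this TVM build.
if not tvm.module.enabled(tgt):
    print("Target %s is not enabled in this TVM build, falling back to llvm" % tgt)
    tgt = "llvm"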
 ######################################################################
 # Run the Function
@@ -124,12 +131,13 @@ fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
 # - fadd runs the actual computation.
 # - asnumpy() copies the GPU array back to the CPU, so we can use it to verify correctness
 #
-ctx = tvm.gpu(0)
+ctx = tvm.context(tgt, 0)
 n = 1024
 a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
 b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
 c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-fadd_cuda(a, b, c)
+fadd(a, b, c)
 np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
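Beyond checking correctness, the compiled function can also be timed on the chosen device. A minimal sketch, assuming this TVM version provides Module.time_evaluator (it is not used in the original tutorial):

# Sketch only: average the kernel time over a few runs on the selected context.
evaluator = fadd.time_evaluator(fadd.entry_name, ctx, number=10)
print("myadd: %g secs per run" % evaluator(a, b, c).mean)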
 ######################################################################
@@ -137,13 +145,16 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # --------------------------
 # You can inspect the generated code in TVM. The result of tvm.build
 # is a tvm Module. fadd is the host module that contains the host wrapper,
-# it also contains a device module for the CUDA function.
+# it also contains a device module for the CUDA (GPU) function.
 #
 # The following code fetches the device module and prints the content code.
 #
-dev_module = fadd_cuda.imported_modules[0]
-print("-----CUDA code-----")
-print(dev_module.get_source())
+if tgt == "cuda":
+    dev_module = fadd.imported_modules[0]
+    print("-----GPU code-----")
+    print(dev_module.get_source())
+else:
+    print(fadd.get_source())
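Besides the device or host source, the loop-level IR that TVM lowers the schedule to can also be printed, which is often the easier thing to read when debugging a schedule. A minimal sketch, assuming tvm.lower with simple_mode=True is available in this version:

# Sketch only: print the lowered loop nest instead of the backend source.
print(tvm.lower(s, [A, B, C], simple_mode=True))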
 ######################################################################
 # .. note:: Code Specialization
@@ -179,8 +190,9 @@ from tvm.contrib import cc
 from tvm.contrib import util
 temp = util.tempdir()
-fadd_cuda.save(temp.relpath("myadd.o"))
-fadd_cuda.imported_modules[0].save(temp.relpath("myadd.ptx"))
+fadd.save(temp.relpath("myadd.o"))
+if tgt == "cuda":
+    fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
 print(temp.listdir())
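The .ptx file only exists on the CUDA path. For OpenCL the device module is plain source text, so one could (this is an assumption, not part of the commit) dump it next to the host object for inspection, using only get_source() and ordinary file I/O:

# Sketch only: write the OpenCL kernel source alongside myadd.o / myadd.so.
if tgt == "opencl":
    with open(temp.relpath("myadd.cl"), "w") as f:
        f.write(fadd.imported_modules[0].get_source())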
@@ -201,8 +213,9 @@ print(temp.listdir())
 # re-link them together. We can verify that the newly loaded function works.
 #
 fadd1 = tvm.module.load(temp.relpath("myadd.so"))
-fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
-fadd1.import_module(fadd1_dev)
+if tgt == "cuda":
+    fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
+    fadd1.import_module(fadd1_dev)
 fadd1(a, b, c)
 np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
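The arrays a, b and c were created on ctx = tvm.context(tgt, 0), so the re-loaded module runs on whatever device was selected at the top of the script. As a small extra check (not in the commit), one can run it into a fresh output buffer, using only APIs already shown above:

# Sketch only: verify the re-loaded module with a brand-new output array.
c2 = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd1(a, b, c2)
np.testing.assert_allclose(c2.asnumpy(), a.asnumpy() + b.asnumpy())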
@@ -215,7 +228,7 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # them together with the host code.
 # Currently we support packing of Metal, OpenCL and CUDA modules.
 #
-fadd_cuda.export_library(temp.relpath("myadd_pack.so"))
+fadd.export_library(temp.relpath("myadd_pack.so"))
 fadd2 = tvm.module.load(temp.relpath("myadd_pack.so"))
 fadd2(a, b, c)
 np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
@@ -241,16 +254,17 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # The following code blocks generate OpenCL code, create an array on the OpenCL
 # device, and verify the correctness of the code.
 #
-fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
-print("------opencl code------")
-print(fadd_cl.imported_modules[0].get_source())
-ctx = tvm.cl(0)
-n = 1024
-a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
-b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
-c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-fadd_cl(a, b, c)
-np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+if tgt == "opencl":
+    fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
+    print("------opencl code------")
+    print(fadd_cl.imported_modules[0].get_source())
+    ctx = tvm.cl(0)
+    n = 1024
+    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+    fadd_cl(a, b, c)
+    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 ######################################################################
 # Summary
 ...