Commit 83b24b5b by Junru Shao Committed by Tianqi Chen

[TOPI] Minor fix in the LSTM recipe (#2131)

parent 7858a1ea
"""LSTM Example, still work in progress..""" """LSTM Example, still work in progress.."""
import tvm import tvm
import time
import os import os
import argparse
from tvm.contrib import nvcc from tvm.contrib import nvcc
import numpy as np import numpy as np
...@@ -14,16 +12,19 @@ DETECT_GLOBAL_BARRIER = PERSIST_KERNEL ...@@ -14,16 +12,19 @@ DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
SKIP_CHECK = False SKIP_CHECK = False
UNROLL_WLOAD = True UNROLL_WLOAD = True
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf.""" """Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(code, target="ptx") ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
def write_code(code, fname): def write_code(code, fname):
with open(fname, "w") as f: with open(fname, "w") as f:
f.write(code) f.write(code)
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_postproc(code): def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"): if not os.path.exists("perf"):
...@@ -33,16 +34,16 @@ def tvm_callback_cuda_postproc(code): ...@@ -33,16 +34,16 @@ def tvm_callback_cuda_postproc(code):
code = open("perf/%s_manual.cu" % TASK).read() code = open("perf/%s_manual.cu" % TASK).read()
return code return code
def lstm(): def lstm():
if not PERSIST_KERNEL: if not PERSIST_KERNEL:
raise ValueError("Non persist LSTM not yet supported") raise ValueError("Non persist LSTM not yet supported")
detect_global_barrier = DETECT_GLOBAL_BARRIER
num_thread_y = 8 num_thread_y = 8
num_thread_x = 16 * 3 / 2 num_thread_x = 16 * 3 // 2
num_sm = 24 num_sm = 24
n_num_step = 128 n_num_step = 128
num_step = tvm.var('num_step') num_step = tvm.var('num_step')
num_hidden = 1152 / 2 num_hidden = 1152 // 2
batch_size = 1 batch_size = 1
# Global transition matrix # Global transition matrix
# Input hidden channel can be pre-caculated by a gemm # Input hidden channel can be pre-caculated by a gemm
...@@ -165,11 +166,9 @@ def lstm(): ...@@ -165,11 +166,9 @@ def lstm():
flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a) flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
ctx.sync() ctx.sync()
# measure time cost of second step. # measure time cost of second step.
tstart = time.time() evaluator = flstm.time_evaluator(flstm.entry_name, ctx, 1, repeat=1000)
flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a) eval_result = evaluator(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
ctx.sync() print("Time cost=%g" % eval_result.mean)
tgap = time.time() - tstart
print("Time cost=%g" % tgap)
# set unroll_explicit for more readable code. # set unroll_explicit for more readable code.
with tvm.build_config( with tvm.build_config(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment