Commit 83b24b5b by Junru Shao Committed by Tianqi Chen

[TOPI] Minor fix in the LSTM recipe (#2131)

parent 7858a1ea
"""LSTM Example, still work in progress.."""
import tvm
import time
import os
import argparse
from tvm.contrib import nvcc
import numpy as np
......@@ -14,16 +12,19 @@ DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
SKIP_CHECK = False
UNROLL_WLOAD = True
@tvm.register_func
def tvm_callback_cuda_compile(code):
    """Use nvcc compiler for better perf.

    Registered through ``tvm.register_func`` under this function's name.

    Parameters
    ----------
    code : str
        Source text handed to ``nvcc.compile_cuda``.
        NOTE(review): presumably the generated CUDA kernel source -- confirm
        against the TVM callback contract.

    Returns
    -------
    ptx
        The value returned by ``nvcc.compile_cuda(code, target="ptx")``.
    """
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx
def write_code(code, fname):
    """Write ``code`` to the file ``fname``.

    The file is opened in text mode ``"w"``, so it is created if missing and
    any existing content is replaced.

    Parameters
    ----------
    code : str
        Text to write.
    fname : str
        Path of the destination file.
    """
    # The context manager closes the handle even if the write raises.
    with open(fname, "w") as f:
        f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
......@@ -33,16 +34,16 @@ def tvm_callback_cuda_postproc(code):
code = open("perf/%s_manual.cu" % TASK).read()
return code
def lstm():
if not PERSIST_KERNEL:
raise ValueError("Non persist LSTM not yet supported")
detect_global_barrier = DETECT_GLOBAL_BARRIER
num_thread_y = 8
num_thread_x = 16 * 3 / 2
num_thread_x = 16 * 3 // 2
num_sm = 24
n_num_step = 128
num_step = tvm.var('num_step')
num_hidden = 1152 / 2
num_hidden = 1152 // 2
batch_size = 1
# Global transition matrix
# Input hidden channel can be pre-calculated by a gemm
......@@ -165,11 +166,9 @@ def lstm():
flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
ctx.sync()
# measure time cost of second step.
tstart = time.time()
flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
ctx.sync()
tgap = time.time() - tstart
print("Time cost=%g" % tgap)
evaluator = flstm.time_evaluator(flstm.entry_name, ctx, 1, repeat=1000)
eval_result = evaluator(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
print("Time cost=%g" % eval_result.mean)
# set unroll_explicit for more readable code.
with tvm.build_config(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment