Commit 9ac46bea, authored by Leyuan Wang, committed by Tianqi Chen

[TOPI] Fix conv2d for small input channels (#331)

* __init__ updated

* pull request updated

* build_module added

* typo fixed

* another typo fixed

* conv2d gpu scheduler for two layouts moved to tvm

* changes made according to CR

* conv2d_nchw formatting updated, conv2d_hwcn tests updated

* lint error fixed

* element wise operator schedule fusing fixed for conv2d

* conv2d_nchw topi test added, all resnet workloads now pass

* conv compute lint error fixed

* fixed python 3 compatibility problem

* conv2d tensor input support added, test typo fixed, ir_pass.Simplify changed to util.get_const_int

* fixed channel number < 4 error, also made sure other splitting factors wouldn't be 0
parent 0ad590c0
@@ -39,12 +39,12 @@ def schedule_conv2d_small_batch(outs):
 vthread = 2
 out_filter = min(64, util.get_const_int(Filter.shape[0]))
 in_filter = util.get_const_int(Filter.shape[1])
-opart2 = out_filter//8
+opart2 = max(1, out_filter//8)
 ofactor = out_filter
 wfactor = block_h
-ifactor = in_filter//4
+ifactor = max(8, in_filter//4)
 sfactor = max(1, ofactor//(opart2*2))
-spart = (wfactor + vthread-1) // vthread
+spart = max(1, (wfactor + vthread-1) // vthread)
 block_x = tvm.thread_axis("blockIdx.x")
 block_y = tvm.thread_axis("blockIdx.y")
 block_z = tvm.thread_axis("blockIdx.z")
......
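The effect of the clamps in this hunk can be checked without TVM. The sketch below recomputes the three changed factors in plain Python for a given workload. The helper name `split_factors` and its `clamp` switch are hypothetical and exist only for illustration; `block_h` is passed in as a plain integer because its definition lies outside the visible hunk.

```python
def split_factors(out_channels, in_channels, block_h, vthread=2, clamp=True):
    """Recompute opart2, ifactor and spart as in the hunk above.

    Hypothetical helper for illustration only; it is not part of TOPI.
    clamp=False mirrors the removed lines, clamp=True mirrors the added lines.
    """
    out_filter = min(64, out_channels)
    in_filter = in_channels
    wfactor = block_h

    # Expressions as they stood before the commit.
    opart2 = out_filter // 8
    ifactor = in_filter // 4
    spart = (wfactor + vthread - 1) // vthread

    if clamp:
        # Lower bounds introduced by the commit.
        opart2 = max(1, opart2)
        ifactor = max(8, ifactor)
        spart = max(1, spart)

    return {"opart2": opart2, "ifactor": ifactor, "spart": spart}


# Three input channels (e.g. an RGB image) and four output channels.
before = split_factors(4, 3, block_h=1, clamp=False)
after = split_factors(4, 3, block_h=1, clamp=True)
print(before)
print(after)
```

With fewer than 4 input channels the unclamped `in_filter//4` is 0, and with fewer than 8 output channels the unclamped `opart2` is 0, which would also make the unchanged `sfactor` line divide by zero. The clamped versions keep every factor at 1 or above.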