Unverified Commit 6b840fa9 by Animesh Jain Committed by GitHub

[TOPI x86] Adding unroll_kw config option for depthwise conv2d. (#5197)

parent 54975a3f
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
import tvm import tvm
from tvm import te from tvm import te
from tvm import autotvm from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.pad import pad from ..nn.pad import pad
from ..util import get_const_tuple from ..util import get_const_tuple
from ..nn.util import get_pad_tuple from ..nn.util import get_pad_tuple
...@@ -67,6 +67,7 @@ def _fallback_schedule(cfg, wkl): ...@@ -67,6 +67,7 @@ def _fallback_schedule(cfg, wkl):
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
cfg["unroll_kw"] = OtherOptionEntity(False)
def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype): def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
"""Compute depthwise conv2d with NCHW layout.""" """Compute depthwise conv2d with NCHW layout."""
...@@ -133,6 +134,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, ...@@ -133,6 +134,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
cfg.define_split("tile_ic", in_channel, num_outputs=2) cfg.define_split("tile_ic", in_channel, num_outputs=2)
cfg.define_split("tile_oc", out_channel, num_outputs=2) cfg.define_split("tile_oc", out_channel, num_outputs=2)
cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64) cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
cfg.define_knob("unroll_kw", [True, False])
# get workload and related schedule config # get workload and related schedule config
wkl = _get_workload( wkl = _get_workload(
...@@ -199,6 +201,8 @@ def schedule_depthwise_conv2d_NCHWc(cfg, outs): ...@@ -199,6 +201,8 @@ def schedule_depthwise_conv2d_NCHWc(cfg, outs):
def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output): def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output):
tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1] tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1]
unroll_kw = cfg["unroll_kw"].val
# schedule pad # schedule pad
if isinstance(s[data_vec].op, tvm.te.ComputeOp) \ if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
and "pad" in data_vec.op.tag: and "pad" in data_vec.op.tag:
...@@ -229,6 +233,8 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out ...@@ -229,6 +233,8 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out
_, ic_chunk, oh, ow, ic_block = s[CC].op.axis _, ic_chunk, oh, ow, ic_block = s[CC].op.axis
kh, kw = s[CC].op.reduce_axis kh, kw = s[CC].op.reduce_axis
s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block) s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block)
if unroll_kw:
s[CC].unroll(kw)
s[CC].vectorize(ic_block) s[CC].vectorize(ic_block)
s[CC].unroll(ow) s[CC].unroll(ow)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment