Commit cdee6a79 by Lianmin Zheng Committed by Tianqi Chen

register depthwise conv2d as generic function (#1108)

parent 8eebf5f6
......@@ -9,6 +9,7 @@ from .util import get_pad_tuple
from ..util import simplify
@tvm.target.generic_func
def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
"""Depthwise convolution nchw forward operator.
......@@ -63,6 +64,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
return Output
@tvm.target.generic_func
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment