Commit 7d7d035e by masahi, committed by Tianqi Chen

allow fallback path to non imagenet workloads (#886)

parent 28bb0f68
@@ -66,8 +66,8 @@ def _get_schedule_conv(wkl):
 @conv2d.register("cpu")
 def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
     target = tvm.target.current_target(allow_none=False)
-    if 'avx' in str(target) and layout == 'NCHW':
     wkl = _get_workload(data, kernel, stride, padding, out_dtype)
+    if wkl in _WORKLOADS and 'avx' in str(target) and layout == 'NCHW':
         sch = _get_schedule(wkl)
         return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype)
     elif layout == 'NCHW':
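
The declaration now computes the workload up front and only takes the hand-written AVX path when that workload is one of the pre-tuned entries in _WORKLOADS (the ImageNet-derived shapes); anything else drops through to the plain NCHW declaration. A minimal sketch of the resulting dispatch, assuming the fallback branch returns the generic nn.conv2d_nchw compute (that call is outside this hunk and is an assumption, not shown in the diff):

    # Sketch only: the nn.conv2d_nchw fallback is an assumption; the other names
    # (_WORKLOADS, _get_workload, _get_schedule, _AVX_SCH_TO_DECL_FUNC) come from
    # the diff above.
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    if wkl in _WORKLOADS and 'avx' in str(target) and layout == 'NCHW':
        # pre-tuned (ImageNet) workload: use the AVX-specialized declaration
        sch = _get_schedule(wkl)
        return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride,
                                                padding, layout, out_dtype)
    elif layout == 'NCHW':
        # unknown workload: plain NCHW compute, scheduled later by default_schedule
        return nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)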
@@ -86,6 +86,30 @@ def schedule_conv2d(outs):
     s = tvm.create_schedule([x.op for x in outs])
     target = tvm.target.current_target(allow_none=False)
 
+    def default_schedule(op):
+        """NCHW conv2d schedule for non imagenet workloads"""
+        conv = op.output(0)
+        kernel = op.input_tensors[1]
+        data = op.input_tensors[0]
+        data_pad = None
+        if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+            data_pad = data
+            data = data_pad.op.input_tensors[0]
+
+            n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
+            pad_fused = s[data_pad].fuse(n_pad, c_pad)
+            s[data_pad].parallel(pad_fused)
+        C = conv
+        n, c, h, w = C.op.axis
+        rc, ry, rx = C.op.reduce_axis
+        fused = s[C].fuse(n, c)
+        s[C].parallel(fused)
+        wo, wi = s[C].split(w, factor=16)
+        s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
+        s[C].unroll(rx)
+        s[C].unroll(ry)
+        s[C].vectorize(wi)
+
     def traverse(op):
         """Traverse operators from computation graph"""
         # inline all one-to-one-mapping operators except the last stage (output)
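
default_schedule applies only target-independent optimizations to the plain NCHW convolution: the padding stage (if any) is fused over batch and channel and parallelized, the output's batch and channel loops are fused and parallelized, the width loop is split by 16 and vectorized, the input-channel reduction is hoisted outward, and the small kernel loops are unrolled. A self-contained sketch of the same primitives on a standalone conv2d, written against the 0.x-era tvm/topi API this file uses (shapes are arbitrary; the pad-stage parallelization is omitted for brevity):

    import tvm
    import topi

    # arbitrary non-ImageNet NCHW workload: stride 1, padding 1
    data = tvm.placeholder((1, 3, 32, 32), name='data')
    kernel = tvm.placeholder((8, 3, 3, 3), name='kernel')
    conv = topi.nn.conv2d_nchw(data, kernel, 1, 1)

    s = tvm.create_schedule(conv.op)
    n, c, h, w = conv.op.axis
    rc, ry, rx = conv.op.reduce_axis

    fused = s[conv].fuse(n, c)                     # parallelize batch x out-channel
    s[conv].parallel(fused)
    wo, wi = s[conv].split(w, factor=16)           # 16-wide tiles along output width
    s[conv].reorder(fused, rc, h, wo, ry, rx, wi)  # hoist the channel reduction
    s[conv].unroll(ry)                             # small kernel loops: unroll
    s[conv].unroll(rx)
    s[conv].vectorize(wi)                          # vectorize the inner width tile

    # print the lowered loop nest to inspect the resulting structure
    print(tvm.lower(s, [data, kernel, conv], simple_mode=True))

The printed IR should show the fused loop marked parallel, the rc reduction outermost inside it, and ramp/vector operations over the 16-wide inner width axis.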
@@ -104,6 +128,7 @@ def schedule_conv2d(outs):
 
         if 'conv2d_nchw' in op.tag:
             if 'avx' in str(target):
+                try:
                     output = op.output(0)
                     conv_out = op.input_tensors[0]
                     kernel_vec = conv_out.op.input_tensors[1]
@@ -114,7 +139,6 @@ def schedule_conv2d(outs):
                     if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                         data_pad = data
                         data = data_pad.op.input_tensors[0]
-
                     padding = infer_pad(data, data_pad)
                     if data_pad is None:
                         stride = infer_stride(data, kernel, output)
@@ -125,28 +149,10 @@ def schedule_conv2d(outs):
                     sch = _get_schedule(wkl)
                     _AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec,
                                                     kernel, kernel_vec, conv_out, output, outs[0])
+                except IndexError:
+                    default_schedule(op)
             else:
-                conv = op.output(0)
-                kernel = op.input_tensors[1]
-                data = op.input_tensors[0]
-                data_pad = None
-                if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
-                    data_pad = data
-                    data = data_pad.op.input_tensors[0]
-
-                    n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
-                    pad_fused = s[data_pad].fuse(n_pad, c_pad)
-                    s[data_pad].parallel(pad_fused)
-                C = conv
-                n, c, h, w = C.op.axis
-                rc, ry, rx = C.op.reduce_axis
-                fused = s[C].fuse(n, c)
-                s[C].parallel(fused)
-                wo, wi = s[C].split(w, factor=16)
-                s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
-                s[C].unroll(rx)
-                s[C].unroll(ry)
-                s[C].vectorize(wi)
+                default_schedule(op)
 
     traverse(outs[0].op)
     return s
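
On the schedule side the fallback is driven by the try/except: the AVX scheduling path expects the packed data_vec/kernel_vec tensors produced by the AVX declaration, so when the declaration fell back to the plain NCHW compute, unpacking those inputs (presumably the conv_out.op.input_tensors[1] lookup) raises IndexError and the op is routed to default_schedule; the non-AVX branch now calls the same helper instead of duplicating the code. A hedged end-to-end usage sketch follows; the generic entry points topi.nn.conv2d and topi.generic.schedule_conv2d_nchw, the target string, and the shapes are assumptions about how this file is registered, not part of the diff:

    import tvm
    import topi

    # hypothetical conv2d whose shapes are deliberately not in _WORKLOADS
    data = tvm.placeholder((1, 13, 31, 31), name='data')
    kernel = tvm.placeholder((7, 13, 5, 5), name='kernel')

    target = 'llvm -mcpu=core-avx2'
    with tvm.target.create(target):
        # dispatches to _declaration_conv: unknown workload -> plain NCHW compute
        conv = topi.nn.conv2d(data, kernel, 1, 2, layout='NCHW')
        # dispatches to schedule_conv2d: IndexError path -> default_schedule
        s = topi.generic.schedule_conv2d_nchw([conv])
    func = tvm.build(s, [data, kernel, conv], target=target)

Before this commit such a workload would error out inside the AVX-only path while looking up a schedule for the unknown shapes; with the fallback it compiles using the generic NCHW schedule, just without the AVX-specific packing.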