Commit f8db4af6 by ziheng Committed by Tianqi Chen

[OP] Conv2d and Depthwise Conv2d for Raspberry Pi (#49)

* [TUTORIAL] ImageNet Inference on Raspberry Pi

* Update tvm
parent 0aafbff4
...@@ -39,7 +39,8 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224), ...@@ -39,7 +39,8 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224),
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
image_shape = (3, 224, 224) if image_shape is None:
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape data_shape = (batch_size,) + image_shape
params = {} params = {}
g = graph.create(net) g = graph.create(net)
......
...@@ -108,7 +108,7 @@ def compute_conv2d(attrs, inputs, _): ...@@ -108,7 +108,7 @@ def compute_conv2d(attrs, inputs, _):
assert layout == "NCHW", "only support nchw for now" assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now" assert dilation == (1, 1), "not support dilate now"
if groups == 1: if groups == 1:
out = topi.nn.conv2d_nchw(inputs[0], inputs[1], strides, padding) out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding)
elif groups == get_const_int(inputs[0].shape[1]) and groups == channels: elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding) out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding)
else: else:
...@@ -128,6 +128,12 @@ def schedule_conv2d(attrs, outs, target): ...@@ -128,6 +128,12 @@ def schedule_conv2d(attrs, outs, target):
return topi.cuda.schedule_conv2d_nchw(outs) return topi.cuda.schedule_conv2d_nchw(outs)
return topi.cuda.schedule_depthwise_conv2d_nchw(outs) return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
# naive schedule # naive schedule
if tvm.target.current_target() == tvm.target.rasp():
if groups == 1:
return topi.rasp.schedule_conv2d(outs)
return topi.rasp.schedule_depthwise_conv2d(outs)
return tvm.create_schedule([x.op for x in outs]) return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment