Commit 860adec8 by masahi, committed by Tianqi Chen

Initial NHWC layout support (#376)

* initial NHWC layout support

* remove layout param from softmax

* more nhwc support

* fix typo

* add nhwc layout test

* fix lint

* update tvm

* update for c++ topi

* fix lint

* update tvm
parent 68c03944
@@ -259,10 +259,19 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
 struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
   int scale;
+  int layout;
   DMLC_DECLARE_PARAMETER(UpSamplingParam) {
     DMLC_DECLARE_FIELD(scale)
     .describe("upsampling scaling factor");
+    DMLC_DECLARE_FIELD(layout)
+    .add_enum("NCHW", kNCHW)
+    .add_enum("NHWC", kNHWC)
+    .set_default(kNCHW)
+    .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
+              "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+              "dimensions respectively. Convolution is applied on the 'H' and"
+              "'W' dimensions.");
   }
 };
...
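Because the new field defaults to kNCHW, existing NCHW graphs keep working unchanged, while NHWC graphs opt in explicitly. A minimal frontend sketch (shapes illustrative):

import nnvm.symbol as sym

data = sym.Variable("data")
# layout defaults to "NCHW"; NHWC must be requested explicitly
nchw_up = sym.upsampling(data=data, scale=2)
nhwc_up = sym.upsampling(data=data, scale=2, layout="NHWC")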
@@ -29,7 +29,6 @@ reg.register_schedule("pad", _fschedule_broadcast)
 reg.register_pattern("pad", OpPattern.INJECTIVE)
 
-# softmax
 @reg.register_schedule("softmax")
 def schedule_softmax(_, outs, target):
     """Schedule definition of softmax"""
@@ -77,17 +76,18 @@ def compute_conv2d(attrs, inputs, _):
     groups = attrs.get_int("groups")
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
-    assert layout == "NCHW", "only support nchw for now"
+    assert layout == "NCHW" or layout == "NHWC"
     assert dilation == (1, 1), "not support dilate now"
     if groups == 1:
-        out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding)
+        out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, layout)
     elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding)
     else:
         raise ValueError("not support arbitrary group number for now")
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
-        bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
+        expand_axis = 1 if layout == "NCHW" else 0
+        bias = topi.expand_dims(bias, axis=expand_axis, num_newaxis=2)
         out = topi.broadcast_add(out, bias)
     return out
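The bias path is the only part of compute_conv2d that needs a layout switch: expand_dims places the two new axes so that the (C,) bias broadcasts against the convolution output in either layout. A NumPy sketch of the same broadcasting (illustrative only, not part of the commit):

import numpy as np

bias = np.random.rand(8).astype(np.float32)   # (C,) with C = 8

# NCHW: expand_dims(axis=1, num_newaxis=2) -> (C, 1, 1), broadcasts with (N, C, H, W)
nchw_out = np.zeros((1, 8, 4, 4), np.float32)
print((nchw_out + bias.reshape(8, 1, 1)).shape)   # (1, 8, 4, 4)

# NHWC: expand_dims(axis=0, num_newaxis=2) -> (1, 1, C), broadcasts with (N, H, W, C)
nhwc_out = np.zeros((1, 4, 4, 8), np.float32)
print((nhwc_out + bias.reshape(1, 1, 8)).shape)   # (1, 4, 4, 8)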
@@ -95,9 +95,12 @@ def compute_conv2d(attrs, inputs, _):
 def schedule_conv2d(attrs, outs, target):
     """Schedule definition of conv2d"""
     groups = attrs.get_int("groups")
+    layout = attrs["layout"]
     with tvm.target.create(target):
-        if groups == 1:
+        if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
+        elif groups == 1 and layout == "NHWC":
+            return topi.generic.schedule_conv2d_nhwc(outs)
         return topi.generic.schedule_depthwise_conv2d_nchw(outs)
 
 reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -178,7 +181,9 @@ reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 def compute_upsampling(attrs, inputs, _):
     """Compute definition of upsampling"""
     scale = attrs.get_int("scale")
-    return topi.nn.upsampling(inputs[0], scale)
+    layout = attrs["layout"]
+    assert layout == "NCHW" or layout == "NHWC"
+    return topi.nn.upsampling(inputs[0], scale, layout)
 
 @reg.register_schedule("upsampling")
 def schedule_upsampling(_, outs, target):
...
@@ -279,8 +279,7 @@ NNVM_REGISTER_OP(softmax)
     const Array<Tensor>& inputs,
     const Array<Tensor>& out_info) {
     const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-    CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
-    return Array<Tensor>{ topi::nn::softmax(inputs[0]) };
+    return Array<Tensor>{ topi::nn::softmax(inputs[0], param.axis) };
   })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
...
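With the axis forwarded to topi::nn::softmax, an NHWC graph can normalize over channels by passing axis=3 instead of relying on the old axis=-1 restriction (the layout test below does exactly that). An illustrative NumPy check of the equivalence, not part of the commit:

import numpy as np

def softmax(a, axis):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.rand(1, 8, 4, 4).astype(np.float32)     # NCHW activations
nchw = softmax(x, axis=1)                              # channel axis in NCHW
nhwc = softmax(x.transpose(0, 2, 3, 1), axis=3)        # channel axis in NHWC
np.testing.assert_allclose(nchw, nhwc.transpose(0, 3, 1, 2), rtol=1e-6)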
@@ -91,10 +91,11 @@ NNVM_REGISTER_OP(max_pool2d)
   auto strides = ShapeToArray(param.strides);
   auto padding = ShapeToArray(param.padding);
   auto ceil_mode = param.ceil_mode;
-  CHECK_EQ(param.layout, kNCHW)
-    << "max_pool2d currently only supports NCHW layout";
+  CHECK(param.layout == kNCHW || param.layout == kNHWC) << "Unsupported layout";
+  std::string layout = (param.layout == kNCHW ? "NCHW" : "NHWC");
   return Array<Tensor>{
-    topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kMaxPool, ceil_mode) };
+    topi::nn::pool(inputs[0], pool_size, strides, padding, \
+                   topi::nn::kMaxPool, ceil_mode, layout) };
 })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
...
@@ -152,10 +153,11 @@ NNVM_REGISTER_OP(avg_pool2d)
   auto strides = ShapeToArray(param.strides);
   auto padding = ShapeToArray(param.padding);
   auto ceil_mode = param.ceil_mode;
-  CHECK_EQ(param.layout, kNCHW)
-    << "avg_pool2d currently only supports NCHW layout";
+  CHECK(param.layout == kNCHW || param.layout == kNHWC) << "Unsupported layout";
+  std::string layout = (param.layout == kNCHW ? "NCHW" : "NHWC");
   return Array<Tensor>{
-    topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kAvgPool, ceil_mode) };
+    topi::nn::pool(inputs[0], pool_size, strides, padding, \
+                   topi::nn::kAvgPool, ceil_mode, layout) };
 })
 .set_num_outputs(1)
 .set_num_inputs(1)
...
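With the layout string forwarded to topi::nn::pool, NHWC pooling is available from the frontend as well; a minimal sketch (the shapes in the comments are illustrative):

import nnvm.symbol as sym

# 2x2 max pooling on an NHWC tensor; H and W are still the pooled axes,
# they just sit at positions 1 and 2 instead of 2 and 3.
data = sym.Variable("data")                       # e.g. (1, 224, 224, 8)
out = sym.max_pool2d(data=data, pool_size=(2, 2), strides=(2, 2), layout="NHWC")
# output shape: (1, 112, 112, 8)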
@@ -24,9 +24,11 @@ inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_shape->size(), 1U);
   TShape dshape = (*in_shape)[0];
   if (dshape.ndim() == 0) return false;
+  dshape = ConvertLayout(dshape, param.layout, kNCHW);
   TShape oshape = dshape;
   oshape[2] = oshape[2] * param.scale;
   oshape[3] = oshape[3] * param.scale;
+  oshape = ConvertLayout(oshape, kNCHW, param.layout);
   NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
   return true;
 }
...
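The round trip through ConvertLayout keeps the scaling logic written against the NCHW indices (2 and 3) while still accepting NHWC input shapes. The same computation traced in plain Python (illustrative, not the nnvm code):

def upsampling_out_shape(dshape, scale, layout):
    # normalize to NCHW, scale H and W, then convert back to the input layout
    if layout == "NHWC":
        n, h, w, c = dshape
        dshape = (n, c, h, w)
    n, c, h, w = dshape
    oshape = (n, c, h * scale, w * scale)
    if layout == "NHWC":
        n, c, h, w = oshape
        oshape = (n, h, w, c)
    return oshape

print(upsampling_out_shape((1, 112, 112, 8), 2, "NHWC"))  # (1, 224, 224, 8)
print(upsampling_out_shape((1, 8, 112, 112), 2, "NCHW"))  # (1, 8, 224, 224)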
import numpy as np
import tvm
from tvm.contrib import graph_runtime as runtime
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list


def get_sym(layout, channels):
    data = sym.Variable(name="data")
    data = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1), layout=layout, use_bias=True)
    data = sym.max_pool2d(data=data, pool_size=(2, 2), strides=(2, 2), layout=layout)
    data = sym.upsampling(data=data, scale=2, layout=layout)
    # softmax over the channel axis: axis 1 for NCHW, axis 3 for NHWC
    softmax_axis = 1
    if layout == "NHWC":
        softmax_axis = 3
    data = sym.softmax(data=data, axis=softmax_axis)
    return data


def build_and_run(sym, params, data, out_shape):
    ctx = tvm.cpu(0)
    graph, lib, params = nnvm.compiler.build(sym, "llvm", shape={"data": data.shape}, params=params)
    module = runtime.create(graph, lib, ctx)
    module.set_input(**params)
    module.set_input("data", data)
    module.run()
    out = module.get_output(0, tvm.nd.empty(out_shape))
    return out.asnumpy()


def test_nhwc():
    data_shape = (1, 3, 224, 224)
    out_channel = 8
    nchw_sym = get_sym("NCHW", out_channel)
    nhwc_sym = get_sym("NHWC", out_channel)
    conv_weight = np.random.uniform(-1, 1, (out_channel, 3, 3, 3)).astype(np.float32)
    conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32)
    nchw_params = {
        "conv2d0_weight": tvm.nd.array(conv_weight, ctx=tvm.cpu(0)),
        "conv2d0_bias": tvm.nd.array(conv_bias, ctx=tvm.cpu(0))
    }
    # the NHWC graph expects HWIO weights, so transpose the OIHW weight accordingly
    nhwc_params = {
        "conv2d1_weight": tvm.nd.array(conv_weight.transpose(2, 3, 1, 0), ctx=tvm.cpu(0)),
        "conv2d1_bias": tvm.nd.array(conv_bias, ctx=tvm.cpu(0))
    }
    data = np.random.uniform(-1, 1, data_shape).astype(np.float32)
    oshape = (1, out_channel, 224, 224)
    oshape_nhwc = (1, 224, 224, out_channel)
    nchw_output = build_and_run(nchw_sym, nchw_params, data, oshape)
    # feed the same data in NHWC order, then compare against the NCHW result
    nhwc_output = build_and_run(nhwc_sym, nhwc_params, data.transpose(0, 2, 3, 1), oshape_nhwc)
    np.testing.assert_allclose(nchw_output, nhwc_output.transpose(0, 3, 1, 2), rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
    test_nhwc()