Commit b154e6b9 by Tianqi Chen (committed by GitHub)

[NNVM] Initial mixed precision support of conv2d (#1356)

parent 4fb58115
@@ -11,6 +11,7 @@
 #include <nnvm/tuple.h>
 #include <nnvm/layout.h>
 #include <string>
+#include "./tensor.h"
 
 namespace nnvm {
 namespace top {
@@ -122,6 +123,7 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
   std::string layout;
   std::string kernel_layout;
   std::string out_layout;
+  int out_dtype;
   bool use_bias;
 
   DMLC_DECLARE_PARAMETER(Conv2DParam) {
@@ -156,6 +158,11 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
     .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
               "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
               "dimensions respectively.");
+    DMLC_DECLARE_DTYPE_FIELD(out_dtype)
+      .add_enum("same", -1)
+      .set_default(-1)
+      .describe("Output data type, set to explicit type under mixed precision setting");
     DMLC_DECLARE_FIELD(use_bias).set_default(true)
       .describe("Whether the layer uses a bias vector.");
   }
......
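The new `out_dtype` field above defaults to the enum value `"same"` (-1), i.e. the output keeps the input dtype unless an explicit type is requested. A minimal frontend sketch, mirroring the test added later in this commit (shapes and dtypes chosen for illustration):

```python
import nnvm.symbol as sym

# int8 data/weight with int32 accumulation; omitting out_dtype (or passing
# out_dtype="same") keeps the output in the input dtype.
x = sym.Variable("x")
y = sym.conv2d(x, channels=10, kernel_size=(3, 3), padding=(1, 1),
               use_bias=False, out_dtype="int32")
```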
@@ -88,6 +88,8 @@ def compute_conv2d(attrs, inputs, _):
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
     kernel_layout = attrs["kernel_layout"]
+    out_dtype = attrs["out_dtype"]
+    out_dtype = None if out_dtype == "same" else out_dtype
     assert layout == "NCHW" or layout == "NHWC"
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
@@ -100,16 +102,19 @@ def compute_conv2d(attrs, inputs, _):
         kernel = topi.nn.dilate(inputs[1], [1, dilation_h, dilation_w, 1])
     if groups == 1:
-        out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
+        out = topi.nn.conv2d(
+            inputs[0], kernel, strides, padding, layout, out_dtype=out_dtype)
     elif layout == "NCHW" and \
          groups == get_const_int(inputs[0].shape[1]) and \
          groups == channels:
-        out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
+        out = topi.nn.depthwise_conv2d_nchw(
+            inputs[0], kernel, strides, padding, out_dtype=out_dtype)
     elif layout == "NHWC" and \
          kernel_layout == "HWOI" and \
          groups == get_const_int(inputs[0].shape[3]) and \
          groups == channels:
-        out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding)
+        out = topi.nn.depthwise_conv2d_nhwc(
+            inputs[0], kernel, strides, padding, out_dtype=out_dtype)
     else:
         raise ValueError("not support arbitrary group number for now")
@@ -127,6 +132,7 @@ def schedule_conv2d(attrs, outs, target):
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
     kernel_layout = attrs["kernel_layout"]
     with tvm.target.create(target):
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
......
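A standalone sketch of the TOPI call that `compute_conv2d` now issues, assuming the pre-0.5 `tvm`/`topi` Python API (`tvm.placeholder`, `import topi`) and illustrative NCHW shapes; the `"same"` attribute value is mapped to `None` so TOPI falls back to the input dtype:

```python
import tvm
import topi

data = tvm.placeholder((1, 3, 18, 18), name="data", dtype="int8")
kernel = tvm.placeholder((10, 3, 3, 3), name="kernel", dtype="int8")

out_dtype = "int32"  # attrs["out_dtype"]; "same" would have become None here
out = topi.nn.conv2d(data, kernel, 1, 1, "NCHW", out_dtype=out_dtype)
assert out.dtype == "int32"  # values are accumulated in int32
```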
@@ -130,6 +130,30 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+inline bool Conv2DInferType(const nnvm::NodeAttrs& attrs,
+                            std::vector<int>* in_type,
+                            std::vector<int>* out_type) {
+  const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
+  if (param.use_bias) {
+    CHECK_EQ(in_type->size(), 3U) << "Input:[data, weight, bias]";
+  } else {
+    CHECK_EQ(in_type->size(), 2U) << "Input:[data, weight]";
+  }
+  CHECK_EQ(out_type->size(), 1U);
+  if (param.out_dtype != -1) {
+    CHECK(!type_is_none((*in_type)[0]));
+    for (size_t i = 1; i < in_type->size(); ++i) {
+      NNVM_ASSIGN_INPUT_TYPE(attrs, *in_type, i, (*in_type)[0]);
+    }
+    NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
+  } else {
+    ElemwiseType<-1, 1>(attrs, in_type, out_type);
+  }
+  return true;
+}
+
 inline bool Conv2DCorrectLayout(const NodeAttrs& attrs,
                                 std::vector<Layout> *ilayouts,
                                 const std::vector<Layout> *last_ilayouts,
@@ -189,7 +213,7 @@ a bias vector is created and added to the outputs.
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
 .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
 .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
-.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FInferType>("FInferType", Conv2DInferType)
 .set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
 .set_num_outputs(1)
 .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
@@ -214,7 +238,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
 .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
 .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
-.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FInferType>("FInferType", Conv2DInferType)
 .set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
 .set_num_outputs(1)
 .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
@@ -348,7 +372,7 @@ said convolution.
 - **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
 - **bias**: (channels,)
 - **out**: This depends on the `layout` parameter. Output is 4D array of shape
-           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. v
+           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
            out_height and out_width are calculated as::
              out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
......
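In plain terms, the new `Conv2DInferType` makes every input follow the data dtype and assigns the single output `out_dtype` when one is set, falling back to the usual elementwise same-type rule otherwise. A hypothetical Python paraphrase of that rule (not part of the codebase), just to make the behaviour concrete:

```python
def conv2d_infer_type(in_dtypes, out_dtype=None):
    """Paraphrase of Conv2DInferType: in_dtypes = [data, weight(, bias)]."""
    data_dtype = in_dtypes[0]
    in_dtypes = [data_dtype] * len(in_dtypes)   # weight/bias follow the data dtype
    out = out_dtype if out_dtype is not None else data_dtype
    return in_dtypes, [out]

# int8 conv with explicit int32 output (mixed precision)
assert conv2d_infer_type(["int8", "int8"], "int32") == (["int8", "int8"], ["int32"])
# default case: behaves like ElemwiseType<-1, 1>
assert conv2d_infer_type(["float32", "float32"]) == (["float32", "float32"], ["float32"])
```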
@@ -32,6 +32,35 @@ def test_conv2d():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
+
+def test_mixed_precision():
+    x = sym.Variable("x")
+    dtype = "int8"
+    out_dtype="int32"
+    y = sym.conv2d(x,
+                   channels=10,
+                   kernel_size=(3,3),
+                   name="y",
+                   padding=(1,1),
+                   use_bias=False,
+                   out_dtype="int32")
+    dshape = (1, 3, 18, 18)
+    kshape = (10, 3, 3, 3)
+    oshape = (1, 10, 18, 18)
+    shape_dict = {"x": dshape}
+    dtype_dict = {"x": dtype}
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict, dtype_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+        data = tvm.nd.array(np.random.uniform(-127, 127, size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(-127, 127, size=kshape).astype(dtype))
+        m.run(x=data, y_weight=kernel)
+        out = m.get_output(0, tvm.nd.empty(oshape, out_dtype))
+        c_np = topi.testing.conv2d_nchw_python(
+            data.asnumpy().astype(out_dtype),
+            kernel.asnumpy().astype(out_dtype), 1, 1)
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
+
 def test_dilated_conv2d():
     dilation = 3
     x = sym.Variable("x")
@@ -167,7 +196,7 @@ def test_avg_pool2d_no_count_pad():
     kh, kw = (4, 4)
     sh, sw = (2, 2)
     ph, pw = (2, 2)
     x = sym.Variable("x")
     y = sym.avg_pool2d(x, pool_size=(kh, kw), strides=(sw, sw), padding=(ph, pw),
                        name="y", count_include_pad=False)
@@ -181,7 +210,7 @@ def test_avg_pool2d_no_count_pad():
     no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
     pad_np[np.ix_(*no_zero)] = a_np
     b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
     for i in range(oh):
         for j in range(ow):
             pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3))
@@ -289,6 +318,7 @@ def test_resize_bilinear():
     np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
 if __name__ == "__main__":
+    test_mixed_precision()
     test_conv2d()
     test_dilated_conv2d()
     test_grouped_conv2d_nchw()
......
@@ -168,6 +168,27 @@ def test_conv2d():
                 layout="NHWC")
 
+
+def test_conv2d_packed():
+    def check(in_shape,
+              out_shape,
+              kernel_shape,
+              **kwargs):
+        x = sym.Variable("x", shape=in_shape)
+        y = sym.conv2d(x, name="y", **kwargs)
+        sdict = infer_shape(y)
+        assert(tuple(sdict["y"][0]) == tuple(out_shape))
+        assert(tuple(sdict["y_weight"][0]) == tuple(kernel_shape))
+
+    check((4, 10, 10, 12, 1, 8),
+          (4, 10, 10, 2, 1, 8),
+          (2, 12, 3, 3, 8, 8),
+          channels=8 * 2,
+          kernel_size=(3,3),
+          padding=(1,1),
+          layout="NHWC1n8c",
+          kernel_layout="OIHW8o8i")
+
+
 def test_conv2d_transpose():
     def check(in_shape, out_shape, **kwargs):
         x = sym.Variable("x", shape=in_shape)
@@ -332,6 +353,7 @@ def test_reduce():
 
 if __name__ == "__main__":
+    test_conv2d_packed()
     test_expand_dims()
     test_dense()
     test_matmul()
......
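For reference, the expected shapes in `test_conv2d_packed` follow from the packed layouts: `NHWC1n8c` splits N by 1 and C by 8, and `OIHW8o8i` splits both the output and input channel axes by 8. A quick plain-Python check of that arithmetic (illustrative only):

```python
# input: NHWC1n8c -> (N, H, W, C//8, 1, 8)
n, h, w, c_outer, n_inner, c_inner = (4, 10, 10, 12, 1, 8)
in_channels = c_outer * c_inner        # 96 logical input channels
out_channels = 8 * 2                   # the `channels` argument

# weight: OIHW8o8i -> (O//8, I//8, kH, kW, 8, 8)
kernel_shape = (out_channels // 8, in_channels // 8, 3, 3, 8, 8)
# a 3x3 kernel with padding (1, 1) keeps the spatial size
out_shape = (n, h, w, out_channels // c_inner, n_inner, c_inner)

assert kernel_shape == (2, 12, 3, 3, 8, 8)
assert out_shape == (4, 10, 10, 2, 1, 8)
```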
@@ -27,12 +27,15 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     padding : int or str
         Padding size, or ['VALID', 'SAME']
 
+    out_dtype: str, optional
+        Output data type
+
     Returns
     -------
     Output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    out_dtype = Input.dtype
+    out_dtype = Input.dtype if out_dtype is None else out_dtype
     batch, in_channel, in_height, in_width = Input.shape
     filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
@@ -65,7 +68,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
 
 @tvm.target.generic_func
-def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
+def depthwise_conv2d_nhwc(Input, Filter, stride, padding, out_dtype=None):
     """Depthwise convolution nhwc forward operator.
 
     Parameters
@@ -82,11 +85,16 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
     padding : int or str
         Padding size, or ['VALID', 'SAME']
 
+    out_dtype: str, optional
+        Output data type
+
     Returns
     -------
     Output : tvm.Tensor
         4-D with shape [batch, out_height, out_width, out_channel]
     """
+    out_dtype = Input.dtype if out_dtype is None else out_dtype
+
     batch, in_height, in_width, in_channel = Input.shape
     filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
     if isinstance(stride, int):
@@ -110,8 +118,9 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
     Output = tvm.compute(
         (batch, out_height, out_width, out_channel),
         lambda b, i, j, c: tvm.sum(
-            (PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] *
-             Filter[di, dj, c/channel_multiplier, c%channel_multiplier]),
+            (PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier].astype(
+                out_dtype) *
+             Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
     return Output
......
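A hedged usage sketch of the updated depthwise operator, again assuming the pre-0.5 `tvm`/`topi` API: int8 input and filter with the new `out_dtype` argument requesting int32 accumulation (shapes are illustrative):

```python
import tvm
import topi

# NHWC input and (filter_h, filter_w, in_channel, channel_multiplier) filter
Input = tvm.placeholder((1, 18, 18, 32), name="Input", dtype="int8")
Filter = tvm.placeholder((3, 3, 32, 1), name="Filter", dtype="int8")

Output = topi.nn.depthwise_conv2d_nhwc(Input, Filter, 1, "SAME", out_dtype="int32")
assert Output.dtype == "int32"  # operands are cast via .astype(out_dtype) before the sum
```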