Commit 464ebb13 by Animesh Jain Committed by Zhi

[QNN] Lowering for Depthwise Convolution. (#4351)

parent 2672aad4
......@@ -503,6 +503,8 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
Expr MakeConcatenate(Expr data, int axis);
Expr MakeRepeat(Expr data, int repeats, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
Expr MakeStack(Expr data, int axis);
......
......@@ -39,9 +39,7 @@ namespace qnn {
// relay.op.qnn.conv2d
TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs);
bool QnnConv2DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
......@@ -59,8 +57,13 @@ bool QnnConv2DRel(const Array<Type>& types,
return Conv2DRel<QnnConv2DAttrs>(types, num_inputs, attrs, reporter);
}
// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w
using WorkloadType = std::tuple<int, int, int, int, int>;
/*
 * \brief Tells whether the qnn.conv2d attributes describe a depthwise convolution.
 * \param param The qnn conv2d attributes.
 * \return True when `channels` is defined, equals `groups`, and `groups` is not 1.
 */
bool is_depthwise(const QnnConv2DAttrs* param) {
  // A single group is never treated as depthwise.
  if (param->groups == 1) {
    return false;
  }
  // Without a channel count there is nothing to compare the group count against.
  if (!param->channels.defined()) {
    return false;
  }
  return tvm::ir::Equal(param->channels, param->groups);
}
// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier
using WorkloadType = std::tuple<int, int, int, int, int, int>;
/*
* \brief Get the conv parameters like batch_size, kernel_height etc.
......@@ -84,26 +87,39 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
const auto kernel_shape = get_shape(arg_types[1]);
int out_channels, kernel_h, kernel_w;
int channel_multiplier = -1;
bool depthwise = is_depthwise(param);
if (param->kernel_layout == "OIHW") {
out_channels = get_const_int(kernel_shape[0]);
kernel_h = get_const_int(kernel_shape[2]);
kernel_w = get_const_int(kernel_shape[3]);
if (depthwise) {
channel_multiplier = get_const_int(kernel_shape[1]);
}
} else if (param->kernel_layout == "HWIO") {
kernel_h = get_const_int(kernel_shape[0]);
kernel_w = get_const_int(kernel_shape[1]);
out_channels = get_const_int(kernel_shape[3]);
if (depthwise) {
channel_multiplier = get_const_int(kernel_shape[2]);
}
} else if (param->kernel_layout == "HWOI") {
kernel_h = get_const_int(kernel_shape[0]);
kernel_w = get_const_int(kernel_shape[1]);
out_channels = get_const_int(kernel_shape[2]);
if (depthwise) {
channel_multiplier = get_const_int(kernel_shape[3]);
}
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
}
return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w);
return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w,
channel_multiplier);
}
/*
* \brief Fallback to simpler lowering for dilation or depthwise conv.
* \brief Fallback to simpler lowering for dilation or grouped conv.
* \param data The input expr.
* \param weight The weight expr.
* \param param The qnn conv2d attributes.
......@@ -167,6 +183,129 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
}
/*
 * \brief Builds the second term of the qnn.conv2d depthwise lowering sequence.
 * \param padded_data The padded data expr.
 * \param param The qnn conv2d attributes.
 * \param kernel_h The height of kernel.
 * \param kernel_w The width of kernel.
 * \param channel_multiplier The channel/depth multiplier (cm).
 * \return The sequence of Relay operators for term2.
 * \note Term2 has the form
 *
 *       Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s)
 *
 *       No single Relay operator expresses this term. The (r, s) window is
 *       reduced with avg_pool2d, and the result is then repeated cm times
 *       along the channel axis.
 */
Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
                               int kernel_w, int channel_multiplier) {
  const int pool_area = kernel_h * kernel_w;

  // Constant Expr for the kernel zero point.
  auto kernel_zp_expr = MakeConstantScalar(Int(32), param->kernel_zero_point);

  auto term = Cast(padded_data, Int(32));

  // avg_pool2d yields the window sum divided by the pool size using integer (floor) division.
  // Scaling the data by the pool size *before* pooling keeps the result exact; scaling afterwards
  // would lose precision. A 1x1 window needs no pooling at all.
  if (pool_area != 1) {
    auto prescaled = Multiply(term, MakeConstantScalar(Int(32), pool_area));
    Array<IndexExpr> no_padding({0, 0});
    term = AvgPool2D(prescaled, param->kernel_size, param->strides, no_padding,
                     param->data_layout,
                     false,   // ceil_mode
                     false);  // count_include_pad
  }

  // Multiplying by a zero point of 1 is a no-op, so the Multiply is skipped in that case.
  if (param->kernel_zero_point != 1) {
    term = Multiply(kernel_zp_expr, term);
  }

  // Locate the channel axis along which the result is repeated.
  int channel_axis = 0;
  if (param->data_layout == "NCHW") {
    channel_axis = 1;
  } else if (param->data_layout == "NHWC") {
    channel_axis = 3;
  } else {
    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
  }

  // A multiplier of 1 needs no repetition along the channel axis.
  if (channel_multiplier == 1) {
    return term;
  }
  return MakeRepeat(term, channel_multiplier, channel_axis);
}
/*
 * \brief Builds the third term of the qnn.conv2d depthwise lowering sequence.
 * \param weight The weight expr.
 * \param param The qnn conv2d attributes.
 * \param out_channels The number of output channels; multiplied by channel_multiplier to
 *        obtain the extent of the channel axis in the reshaped result.
 * \param channel_multiplier The channel/depth multiplier (cm).
 * \return The sequence of Relay operators for term3.
 * \note Term3 has the form
 *
 *       Sigma(r, s) zp_a * Qw(oc/cm, oc%cm, r, s)
 *
 *       The kernel is summed over its r and s axes. The remaining two axes are
 *       then flattened into the channel axis of a 4D tensor, (1, C, 1, 1) for
 *       NCHW or (1, 1, 1, C) for NHWC.
 */
Expr DepthwiseConv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels,
                              int channel_multiplier) {
  // Constant expr for input zero point.
  auto input_zp_expr = MakeConstantScalar(Int(32), param->input_zero_point);

  // Pick the kernel axes that hold R and S (kernel height and width).
  Array<Integer> spatial_axes;
  if (param->kernel_layout == "OIHW") {
    spatial_axes = {2, 3};
  } else if (param->kernel_layout == "HWIO" || param->kernel_layout == "HWOI") {
    spatial_axes = {0, 1};
  } else {
    LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
  }

  // Sum over the window without keeping the reduced axes.
  // NOTE(review): for HWIO the tensor left after the reduction is (I, O)-ordered; confirm that
  // flattening it below matches the output channel order when channel_multiplier > 1.
  auto window_sum = Sum(Cast(weight, Int(32)), spatial_axes, false, false);

  // Shape of the result depends on where the data layout puts the channel axis.
  const int flat_channels = out_channels * channel_multiplier;
  Array<Integer> target_shape;
  if (param->data_layout == "NCHW") {
    target_shape = {1, flat_channels, 1, 1};
  } else if (param->data_layout == "NHWC") {
    target_shape = {1, 1, 1, flat_channels};
  } else {
    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
  }
  auto per_channel = Reshape(window_sum, target_shape);

  // Multiplying by a zero point of 1 is a no-op, so the Multiply is skipped in that case.
  if (param->input_zero_point != 1) {
    return Multiply(input_zp_expr, per_channel);
  }
  return per_channel;
}
/*
 * \brief Builds the fourth term of the qnn.conv2d depthwise lowering sequence.
 * \param param The qnn conv2d attributes.
 * \param kernel_h The height of kernel.
 * \param kernel_w The width of kernel.
 * \return An int32 scalar constant expr holding term4.
 * \note Term4 has the form
 *
 *       Sigma(r, s) zp_a * zp_w
 *
 *       Every summand is the same constant, so the sum collapses to
 *       zp_a * zp_w * kernel_h * kernel_w.
 */
Expr DepthwiseConv2DFourthTerm(const QnnConv2DAttrs* param, int kernel_h, int kernel_w) {
  int term4_value = param->input_zero_point * param->kernel_zero_point;
  term4_value *= kernel_h;
  term4_value *= kernel_w;
  return MakeConstantScalar(Int(32), term4_value);
}
/*
* \brief Calculates the first term in the qnn.conv2d lowering sequence.
* \param data The input expr.
* \param weight The weight expr.
......@@ -210,7 +349,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
// We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
// Since, this is integer division (floor), we can first multiply the data by the pool_size and
// then perform avg_pool2d. Reversing this causes inaccuracy due to floor division.
auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
Array<IndexExpr> padding({0, 0});
// Reduce the C dimension. Find the dimension.
......@@ -223,11 +361,12 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
}
// Keep dims true to retain 4D tensor
auto reduced_c_t2 = Sum(scaled_hw_t2, axes_t2, true, false);
auto reduced_c_t2 = Sum(casted_t2, axes_t2, true, false);
// If the pool_size is 1x1, we don't need avg_pool2d.
auto reduced_t2 = reduced_c_t2;
if (kernel_h * kernel_w != 1) {
reduced_c_t2 = Multiply(reduced_c_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
reduced_t2 =
AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout,
false, // ceil_mode
......@@ -245,7 +384,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
* \brief Calculates the third term in the qnn.conv2d lowering sequence.
* \param weight The weight expr.
* \param param The qnn conv2d attributes.
* \param batch_size The batch size.
* \param out_channels The number of output channels.
* \return The sequence of Relay operatos for term3.
* \note The term3 looks like this
......@@ -256,8 +394,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
* a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
* format.
*/
Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size,
int out_channels) {
Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels) {
// Constant expr for input zero point.
auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
......@@ -278,9 +415,9 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
// Find the newshape depending on NCHW/NHWC layout.
Array<Integer> newshape;
if (param->data_layout == "NCHW") {
newshape = {batch_size, out_channels, 1, 1};
newshape = {1, out_channels, 1, 1};
} else if (param->data_layout == "NHWC") {
newshape = {batch_size, 1, 1, out_channels};
newshape = {1, 1, 1, out_channels};
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
}
......@@ -295,7 +432,6 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
/*
* \brief Calculates the fourth term in the qnn.conv2d lowering sequence.
* \param param The qnn conv2d attributes.
* \param batch_size The batch size.
* \param in_channels The number of input channels.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
......@@ -305,8 +441,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
* Sigma(c,r,s) zp_a * zp_w
*
*/
Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int batch_size, int in_channels, int kernel_h,
int kernel_w) {
Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int in_channels, int kernel_h, int kernel_w) {
int scalar_term4 =
param->input_zero_point * param->kernel_zero_point * in_channels * kernel_h * kernel_w;
return MakeConstantScalar(Int(32), scalar_term4);
......@@ -391,7 +526,20 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
* gives an opportunity to reuse alter_op_layout infrastructure.
* 3) For dilated conv, in current lowering, we need dilated pool. So as
* a workaround, we fall back to simpler lowering using int32 conv if
* the conv is dilated. We fallback also in case of depthwise conv.
* the conv is dilated. We fallback also in case of grouped conv.
*
* For depthwise, we can similarly unroll the computation. The initial compute is as follows
* where cm = channel_multiplier
*
* Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/cm, oc%cm, r, s) - zp_w)
* * (Qa(n, oc/cm, oh + r, ow + s) - zp_a)
*
* This can be written as
*
* Sigma(r, s) Qw(oc/cm, oc%cm, r, s) * Qa(n, oc/cm, oh + r, ow + s)
* - Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s)
* - Sigma(r, s) zp_a * Qw(oc/cm, oc%cm, r, s)
* + Sigma(r, s) zp_a * zp_w
*
* The whole process can be broken down into following steps
* * Assertion checks for existing support, fallback if necessary
......@@ -417,23 +565,33 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
param->kernel_layout == "HWOI")
<< "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";
int batch_size, in_channels, out_channels, kernel_h, kernel_w;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
GetWorkload(arg_types, param);
// Fallback to int32 conv if there is dilation or depthwise conv2d
// Fallback to int32 conv if there is dilation or grouped conv2d
CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
return Conv2DFallBack(data, weight, param);
} else if (is_depthwise(param)) {
CHECK_NE(channel_multiplier, -1);
auto padded_data = Conv2DPadInput(data, param);
auto term1 = Conv2DFirstTerm(padded_data, weight, param);
auto term2 =
DepthwiseConv2DSecondTerm(padded_data, param, kernel_h, kernel_w, channel_multiplier);
auto term3 = DepthwiseConv2DThirdTerm(weight, param, out_channels, channel_multiplier);
auto term4 = DepthwiseConv2DFourthTerm(param, kernel_h, kernel_w);
return Conv2DCombineTerms(term1, term2, term3, term4, param);
}
auto padded_data = Conv2DPadInput(data, param);
auto term1 = Conv2DFirstTerm(padded_data, weight, param);
auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
auto term3 = Conv2DThirdTerm(weight, param, out_channels);
auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w);
return Conv2DCombineTerms(term1, term2, term3, term4, param);
}
......
......@@ -42,7 +42,9 @@ def get_ref_func(data,
dilation,
data_layout,
kernel_layout,
out_dtype):
out_dtype,
groups,
channels=None):
casted_data = relay.op.cast(data, "int32")
casted_kernel = relay.op.cast(kernel, "int32")
shifted_data = relay.op.subtract(casted_data,
......@@ -54,6 +56,8 @@ def get_ref_func(data,
padding=padding,
strides=strides,
dilation=dilation,
groups=groups,
channels=channels,
kernel_size=kernel_size,
out_dtype=out_dtype,
data_layout=data_layout,
......@@ -74,7 +78,9 @@ def get_qnn_func(data,
dilation,
data_layout,
kernel_layout,
out_dtype):
out_dtype,
groups,
channels=None):
func = relay.qnn.op.conv2d(
data, kernel,
input_zero_point=input_zero_point,
......@@ -86,6 +92,8 @@ def get_qnn_func(data,
dilation=dilation,
padding=padding,
out_dtype=out_dtype,
groups=groups,
channels=channels,
data_layout=data_layout,
kernel_layout=kernel_layout)
......@@ -107,7 +115,9 @@ def get_funcs(data_shape,
dilation,
data_layout,
kernel_layout,
out_dtype):
out_dtype,
groups=1,
channels=None):
data = relay.var("data", shape=data_shape,
dtype=data_dtype)
kernel = relay.var("kernel", shape=kernel_shape,
......@@ -124,8 +134,11 @@ def get_funcs(data_shape,
dilation,
data_layout,
kernel_layout,
out_dtype)
out_dtype,
groups,
channels)
ref_func = run_infer_type(ref_func)
ref_func = relay.Module.from_expr(ref_func)
qnn_func = get_qnn_func(data,
kernel,
input_zero_point,
......@@ -138,7 +151,9 @@ def get_funcs(data_shape,
dilation,
data_layout,
kernel_layout,
out_dtype)
out_dtype,
groups,
channels)
return (ref_func, qnn_func)
def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
......@@ -151,14 +166,14 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
if data_dtype == "uint8":
low = 0
high = 255
golden_data = np.random.random_integers(low=low, high=high,
golden_data = np.random.randint(low=low, high=high,
size=data_shape).astype(data_dtype)
low = -128
high = 127
if kernel_dtype == "uint8":
low = 0
high = 255
golden_weight = np.random.random_integers(low=low, high=high,
golden_weight = np.random.randint(low=low, high=high,
size=kernel_shape).astype(kernel_dtype)
return (golden_data, golden_weight)
......@@ -512,7 +527,7 @@ def test_const_folding():
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
golden_weight = np.random.random_integers(low=0, high=255,
golden_weight = np.random.randint(low=0, high=255,
size=kernel_shape).astype(kernel_dtype)
data = relay.var("data", shape=data_shape,
dtype=data_dtype)
......@@ -529,7 +544,8 @@ def test_const_folding():
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
out_dtype="int32",
groups=1)
folded_mod = transform.FoldConstant()(qnn_func)
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()
......@@ -724,6 +740,112 @@ def test_broadcast_layout():
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
def test_depthwise_depth_multiplier():
    """Check depthwise qnn.conv2d against the reference function.

    Covers uint8 data and kernels with a depth multiplier of 1 and 2, in both
    NCHW/OIHW and NHWC/HWOI layouts. Each case builds the reference and QNN
    functions with get_funcs() and hands them to verify().
    """
    tensor_dtype = 'uint8'

    # Each entry: (data_shape, kernel_shape, data_layout, kernel_layout, groups).
    # `channels` is always passed equal to `groups`.
    workloads = [
        # uint8 input, NCHW and OIHW
        ((2, 4, 16, 16), (4, 1, 3, 3), "NCHW", "OIHW", 4),   # Depthwise multiplier = 1
        ((10, 4, 16, 16), (4, 2, 3, 3), "NCHW", "OIHW", 8),  # Depthwise multiplier = 2
        # uint8 input, NHWC and HWOI
        ((2, 16, 16, 4), (3, 3, 4, 1), "NHWC", "HWOI", 4),   # Depthwise multiplier = 1
        ((2, 16, 16, 4), (3, 3, 4, 2), "NHWC", "HWOI", 8),   # Depthwise multiplier = 2
    ]

    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
        for data_shape, kernel_shape, data_layout, kernel_layout, num_groups in workloads:
            ref_func, qnn_func = get_funcs(data_shape=data_shape,
                                           data_dtype=tensor_dtype,
                                           kernel_shape=kernel_shape,
                                           kernel_dtype=tensor_dtype,
                                           input_zero_point=5,
                                           kernel_zero_point=3,
                                           input_scale=1.0,
                                           kernel_scale=1.0,
                                           kernel_size=(3, 3),
                                           padding=(0, 0),
                                           strides=(1, 1),
                                           dilation=(1, 1),
                                           data_layout=data_layout,
                                           kernel_layout=kernel_layout,
                                           out_dtype="int32",
                                           groups=num_groups,
                                           channels=num_groups)
            verify(ref_func, qnn_func, data_shape, tensor_dtype,
                   kernel_shape, tensor_dtype)
if __name__ == "__main__":
test_no_zero_point()
test_input_zero_point()
......@@ -738,3 +860,4 @@ if __name__ == "__main__":
test_broadcast_layout()
test_tflite_output_multiplier_greater_than_one()
test_tflite_anistropic_strides()
test_depthwise_depth_multiplier()
......@@ -197,6 +197,11 @@ def _conv2d_legalize(attrs, inputs, arg_types):
if not (dilation[0] == 1 and dilation[1] == 1):
return None
# No legalization for grouped or depthwise convolutions (any groups != 1) yet.
groups = attrs.get_int("groups")
if groups != 1:
return None
# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment