Commit 464ebb13 by Animesh Jain Committed by Zhi

[QNN] Lowering for Depthwise Convolution. (#4351)

parent 2672aad4
......@@ -503,6 +503,8 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
Expr MakeConcatenate(Expr data, int axis);
Expr MakeRepeat(Expr data, int repeats, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
Expr MakeStack(Expr data, int axis);
......
......@@ -197,6 +197,11 @@ def _conv2d_legalize(attrs, inputs, arg_types):
if not (dilation[0] == 1 and dilation[1] == 1):
return None
# No legalization for depthwise convolutions yet.
groups = attrs.get_int("groups")
if groups != 1:
return None
# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment