Commit 236d7ef0 by eqy Committed by Tianqi Chen

[TOPI][Relay] Fix default `out_dtype` for `conv2d_NCHWc` and Relay (#2707)

parent 8f5c27bd
......@@ -294,6 +294,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_dtype = attrs["out_dtype"]
layout_name = 'layout' if F == sym else 'data_layout'
......@@ -301,7 +302,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
kh, kw = attrs.get_int_tuple("kernel_size")
dtype = data.dtype
out_dtype = dtype if attrs["out_dtype"] == "same" else attrs["out_dtype"]
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
is_depthwise = groups == in_channel and groups == out_channel
# only optimize for NCHW
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment