Unverified Commit af974c34 by Tianqi Chen Committed by GitHub

[RELAY][OP] Fix conv2d NHWC type inference. (#2019)

parent 42dc24a3
...@@ -35,6 +35,8 @@ struct NodeTypeChecker { ...@@ -35,6 +35,8 @@ struct NodeTypeChecker {
// It can be turned off, but will make non strict checking. // It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI // TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType; using ContainerType = typename T::ContainerType;
// always allow nullptr.
if (sptr == nullptr) return true;
return sptr->derived_from<ContainerType>(); return sptr->derived_from<ContainerType>();
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
...@@ -46,7 +48,7 @@ struct NodeTypeChecker { ...@@ -46,7 +48,7 @@ struct NodeTypeChecker {
template<typename T> template<typename T>
struct NodeTypeChecker<Array<T> > { struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) { static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return true;
if (!sptr->is_type<ArrayNode>()) return false; if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr); ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) { for (const auto& p : n->data) {
...@@ -64,7 +66,7 @@ struct NodeTypeChecker<Array<T> > { ...@@ -64,7 +66,7 @@ struct NodeTypeChecker<Array<T> > {
template<typename V> template<typename V>
struct NodeTypeChecker<Map<std::string, V> > { struct NodeTypeChecker<Map<std::string, V> > {
static inline bool Check(Node* sptr) { static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return true;
if (!sptr->is_type<StrMapNode>()) return false; if (!sptr->is_type<StrMapNode>()) return false;
StrMapNode* n = static_cast<StrMapNode*>(sptr); StrMapNode* n = static_cast<StrMapNode*>(sptr);
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
...@@ -83,7 +85,7 @@ struct NodeTypeChecker<Map<std::string, V> > { ...@@ -83,7 +85,7 @@ struct NodeTypeChecker<Map<std::string, V> > {
template<typename K, typename V> template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > { struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) { static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return true;
if (!sptr->is_type<MapNode>()) return false; if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr); MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
......
...@@ -150,5 +150,10 @@ TVM_REGISTER_NODE_TYPE(OpNode) ...@@ -150,5 +150,10 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name; return static_cast<const OpNode*>(n)->name;
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<OpNode>([](const OpNode* node, tvm::IRPrinter* p) {
p->stream << "Op(" << node->name << ")";
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -21,7 +21,6 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -21,7 +21,6 @@ bool Conv2DRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>(); const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false; if (data == nullptr) return false;
static const Layout kNCHW("NCHW"); static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW"); static const Layout kOIHW("OIHW");
...@@ -42,6 +41,9 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -42,6 +41,9 @@ bool Conv2DRel(const Array<Type>& types,
<< "Conv only support output layouts that are convertible from NCHW." << "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout; << " But got " << out_layout;
std::vector<IndexExpr> dshape_nchw = ConvertLayout(
data->shape, in_layout, kNCHW);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x; IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined // infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) { if (param->kernel_size.defined() && param->channels.defined()) {
...@@ -49,7 +51,7 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -49,7 +51,7 @@ bool Conv2DRel(const Array<Type>& types,
CHECK_EQ(param->dilation.size(), 2); CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape( std::vector<IndexExpr> wshape(
{param->channels / param->groups, {param->channels / param->groups,
data->shape[1] / param->groups, dshape_nchw[1] / param->groups,
param->kernel_size[0], param->kernel_size[0],
param->kernel_size[1]}); param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout); wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
...@@ -78,16 +80,16 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -78,16 +80,16 @@ bool Conv2DRel(const Array<Type>& types,
<< " channels=" << param->channels << " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape); << " wshape=" << Array<IndexExpr>(wshape);
} }
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1])); CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
channels = wshape[0]; channels = wshape[0];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
} }
// dilation // dilation
std::vector<IndexExpr> oshape({data->shape[0], channels, 0, 0}); std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1; oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
oshape[3] = (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1; oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
DataType out_dtype = param->out_dtype; DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) { if (out_dtype.bits() == 0) {
out_dtype = data->dtype; out_dtype = data->dtype;
...@@ -183,7 +185,9 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -183,7 +185,9 @@ bool Conv2DTransposeRel(const Array<Type>& types,
<< " But got "<< kernel_layout; << " But got "<< kernel_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x; IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
const auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
// infer weight if the kernel_size and channels are defined // infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) { if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->kernel_size.size(), 2);
......
...@@ -495,9 +495,7 @@ inline std::vector<IndexExpr> ConvertLayout( ...@@ -495,9 +495,7 @@ inline std::vector<IndexExpr> ConvertLayout(
IndexExpr src_dim_size = src[i]; IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) { if (src_minor_pos >= 0) {
const int64_t* minor_size = as_const_int(src[src_minor_pos]); CHECK(is_const_int(src[src_minor_pos], src_factor))
CHECK(minor_size == nullptr &&
src_factor == minor_size[0])
<< "src shape " << Array<IndexExpr>(src) << "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout " << " does not agree with layout "
<< src_layout; << src_layout;
......
...@@ -32,9 +32,9 @@ def test_conv2d_infer_type(): ...@@ -32,9 +32,9 @@ def test_conv2d_infer_type():
# Infer with a different layout # Infer with a different layout
n, c, h, w = 4, 32, 224, 224 n, c, h, w = 4, 32, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8"))
w = relay.var("w") wt = relay.var("w")
y = relay.nn.conv2d(x, w, y = relay.nn.conv2d(x, wt,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
channels=16, channels=16,
...@@ -47,6 +47,21 @@ def test_conv2d_infer_type(): ...@@ -47,6 +47,21 @@ def test_conv2d_infer_type():
assert yy.args[1].checked_type == relay.TensorType( assert yy.args[1].checked_type == relay.TensorType(
(4, 8, 3, 3, 4, 4), "int8") (4, 8, 3, 3, 4, 4), "int8")
# Infer with NHWC
n, c, h, w = 4, 32, 224, 224
x = relay.var("x", relay.TensorType((n, h, w, c), "int8"))
wt = relay.var("w")
y = relay.nn.conv2d(x, wt,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
data_layout="NHWC",
out_dtype="int32")
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, h, w, 16), "int32")
def test_conv2d_transpose_infer_type(): def test_conv2d_transpose_infer_type():
# symbolic in batch dimension # symbolic in batch dimension
n, c, h, w = tvm.var("n"), 10, 10, 12 n, c, h, w = tvm.var("n"), 10, 10, 12
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment