Unverified commit af974c34 by Tianqi Chen, committed by GitHub

[RELAY][OP] Fix conv2d NHWC type inference. (#2019)

parent 42dc24a3
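What this fixes: `Conv2DRel` used to index `data->shape` directly (`data->shape[1]`, `data->shape[2]`, ...), which hard-codes the NCHW layout; with `data_layout="NHWC"` the channel and spatial extents were read from the wrong axes and type inference produced a wrong output type. The change below converts the data shape to NCHW once, into `dshape_nchw`, and does all shape arithmetic on that. A minimal check of the fixed behavior, lifted almost verbatim from the test this commit adds (same `relay.ir_pass.infer_type` API as in the diff):

    from tvm import relay

    # NHWC input: channels last, so NCHW-style indexing of the shape
    # would read a spatial extent where the channel count is expected.
    n, c, h, w = 4, 32, 224, 224
    x = relay.var("x", relay.TensorType((n, h, w, c), "int8"))
    wt = relay.var("w")  # weight type left undefined; it gets inferred
    y = relay.nn.conv2d(x, wt,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        channels=16,
                        data_layout="NHWC",
                        out_dtype="int32")
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.TensorType((n, h, w, 16), "int32")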
@@ -35,6 +35,8 @@ struct NodeTypeChecker {
     // It can be turned off, but will make non strict checking.
     // TODO(tqchen) possibly find alternative to turn of RTTI
     using ContainerType = typename T::ContainerType;
+    // always allow nullptr.
+    if (sptr == nullptr) return true;
     return sptr->derived_from<ContainerType>();
   }
   static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
@@ -46,7 +48,7 @@ struct NodeTypeChecker {
 template<typename T>
 struct NodeTypeChecker<Array<T> > {
   static inline bool Check(Node* sptr) {
-    if (sptr == nullptr) return false;
+    if (sptr == nullptr) return true;
     if (!sptr->is_type<ArrayNode>()) return false;
     ArrayNode* n = static_cast<ArrayNode*>(sptr);
     for (const auto& p : n->data) {
@@ -64,7 +66,7 @@ struct NodeTypeChecker<Array<T> > {
 template<typename V>
 struct NodeTypeChecker<Map<std::string, V> > {
   static inline bool Check(Node* sptr) {
-    if (sptr == nullptr) return false;
+    if (sptr == nullptr) return true;
     if (!sptr->is_type<StrMapNode>()) return false;
     StrMapNode* n = static_cast<StrMapNode*>(sptr);
     for (const auto& kv : n->data) {
@@ -83,7 +85,7 @@ struct NodeTypeChecker<Map<std::string, V> > {
 template<typename K, typename V>
 struct NodeTypeChecker<Map<K, V> > {
   static inline bool Check(Node* sptr) {
-    if (sptr == nullptr) return false;
+    if (sptr == nullptr) return true;
     if (!sptr->is_type<MapNode>()) return false;
     MapNode* n = static_cast<MapNode*>(sptr);
     for (const auto& kv : n->data) {
...
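Note on the hunks above: all four `NodeTypeChecker` specializations change their null handling from `return false` to `return true`, i.e. an undefined (null) node now passes any type check ("always allow nullptr"). The diff does not spell out the motivation, but the practical effect is that optional arguments left undefined no longer trip the checker, which is the pattern the new tests rely on: a weight variable with no type annotation flows through inference and gets its type filled in. A sketch of that flow, reusing the API from the tests below:

    from tvm import relay

    x = relay.var("x", relay.TensorType((4, 32, 224, 224), "int8"))
    wt = relay.var("w")   # no annotation: type starts out undefined
    y = relay.nn.conv2d(x, wt, kernel_size=(3, 3), padding=(1, 1), channels=16)
    yy = relay.ir_pass.infer_type(y)
    print(yy.args[1].checked_type)   # inferred weight type, e.g. (16, 32, 3, 3), "int8"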
@@ -150,5 +150,10 @@ TVM_REGISTER_NODE_TYPE(OpNode)
     return static_cast<const OpNode*>(n)->name;
   });

+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<OpNode>([](const OpNode* node, tvm::IRPrinter* p) {
+    p->stream << "Op(" << node->name << ")";
+  });
+
 }  // namespace relay
 }  // namespace tvm
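Side change: the hunk above also registers an `IRPrinter` dispatch for `OpNode`, so an operator prints as `Op(<name>)` instead of the generic node dump. Assuming the usual operator lookup via `relay.op.get` (not shown in this diff):

    from tvm import relay

    op = relay.op.get("nn.conv2d")
    print(op)   # prints: Op(nn.conv2d)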
@@ -21,7 +21,6 @@ bool Conv2DRel(const Array<Type>& types,
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr) return false;
-
   static const Layout kNCHW("NCHW");
   static const Layout kOIHW("OIHW");
@@ -42,6 +41,9 @@ bool Conv2DRel(const Array<Type>& types,
       << "Conv only support output layouts that are convertible from NCHW."
       << " But got " << out_layout;

+  std::vector<IndexExpr> dshape_nchw = ConvertLayout(
+      data->shape, in_layout, kNCHW);
+
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
   // infer weight if the kernel_size and channels are defined
   if (param->kernel_size.defined() && param->channels.defined()) {
@@ -49,7 +51,7 @@ bool Conv2DRel(const Array<Type>& types,
     CHECK_EQ(param->dilation.size(), 2);
     std::vector<IndexExpr> wshape(
         {param->channels / param->groups,
-         data->shape[1] / param->groups,
+         dshape_nchw[1] / param->groups,
          param->kernel_size[0],
          param->kernel_size[1]});
     wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
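Worked example of the weight-shape inference above: the shape is assembled in OIHW order, then converted to the requested `kernel_layout`. With `channels=16`, `groups=1`, 32 input channels, and a 3x3 kernel (the numbers used in the tests below):

    # Plain-Python rendering of the wshape arithmetic (OIHW order).
    channels, groups = 16, 1
    in_channels = 32            # dshape_nchw[1] after layout normalization
    kh, kw = 3, 3

    wshape_oihw = (channels // groups, in_channels // groups, kh, kw)
    assert wshape_oihw == (16, 32, 3, 3)
    # ConvertLayout(wshape, kOIHW, kernel_layout) then packs this into the
    # user's kernel layout; for a 4x4-packed layout (presumably OIHW4o4i)
    # the existing test expects (4, 8, 3, 3, 4, 4).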
@@ -78,16 +80,16 @@ bool Conv2DRel(const Array<Type>& types,
           << " channels=" << param->channels
           << " wshape=" << Array<IndexExpr>(wshape);
     }
-    CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
+    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
     channels = wshape[0];
     dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
   }
   // dilation
-  std::vector<IndexExpr> oshape({data->shape[0], channels, 0, 0});
-  oshape[2] = (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
-  oshape[3] = (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
+  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
+  oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
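The output height/width computed above follow the standard convolution formula, now evaluated on the layout-normalized shape. Plugging in the numbers from the NHWC test below (224x224 input, 3x3 kernel, padding 1, stride 1, dilation 1):

    # Same arithmetic as oshape[2]/oshape[3] above.
    h = w = 224
    pad, stride, dilation, k = 1, 1, 1, 3

    dilated_ksize = 1 + (k - 1) * dilation               # = 3
    out_h = (h + 2 * pad - dilated_ksize) // stride + 1  # = 224
    out_w = (w + 2 * pad - dilated_ksize) // stride + 1  # = 224
    assert (out_h, out_w) == (224, 224)  # matches the expected (4, 224, 224, 16)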
@@ -183,7 +185,9 @@ bool Conv2DTransposeRel(const Array<Type>& types,
       << " But got "<< kernel_layout;
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-  const auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
+  auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
+
   // infer weight if the kernel_size and channels are defined
   if (param->kernel_size.defined() && param->channels.defined()) {
     CHECK_EQ(param->kernel_size.size(), 2);
...
@@ -495,9 +495,7 @@ inline std::vector<IndexExpr> ConvertLayout(
     IndexExpr src_dim_size = src[i];
     if (src_minor_pos >= 0) {
-      const int64_t* minor_size = as_const_int(src[src_minor_pos]);
-      CHECK(minor_size == nullptr &&
-            src_factor == minor_size[0])
+      CHECK(is_const_int(src[src_minor_pos], src_factor))
          << "src shape " << Array<IndexExpr>(src)
          << " does not agree with layout "
          << src_layout;
...
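The `ConvertLayout` hunk above fixes an outright bug, independent of NHWC: the old assertion `CHECK(minor_size == nullptr && src_factor == minor_size[0])` has its first condition inverted (the dimension must be a constant, i.e. `as_const_int` must return non-null) and then dereferences `minor_size` exactly in the case it just required to be null, so the check could never pass without crashing. `is_const_int(expr, value)` states the intended condition, "constant and equal to the layout factor", in one safe call. Roughly, in Python terms:

    # Rough Python rendering; None stands in for a null as_const_int result.
    def old_check(minor_size, src_factor):
        # Inverted condition, and minor_size[0] blows up when it is None.
        return minor_size is None and src_factor == minor_size[0]

    def new_check(minor_size, src_factor):
        # What is_const_int(src[src_minor_pos], src_factor) expresses.
        return minor_size is not None and minor_size[0] == src_factor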
@@ -32,9 +32,9 @@ def test_conv2d_infer_type():
     # Infer with a different layout
     n, c, h, w = 4, 32, 224, 224
-    x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
-    w = relay.var("w")
-    y = relay.nn.conv2d(x, w,
+    x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8"))
+    wt = relay.var("w")
+    y = relay.nn.conv2d(x, wt,
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         channels=16,
@@ -47,6 +47,21 @@ def test_conv2d_infer_type():
     assert yy.args[1].checked_type == relay.TensorType(
         (4, 8, 3, 3, 4, 4), "int8")
+
+    # Infer with NHWC
+    n, c, h, w = 4, 32, 224, 224
+    x = relay.var("x", relay.TensorType((n, h, w, c), "int8"))
+    wt = relay.var("w")
+    y = relay.nn.conv2d(x, wt,
+                        kernel_size=(3, 3),
+                        padding=(1, 1),
+                        channels=16,
+                        data_layout="NHWC",
+                        out_dtype="int32")
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, h, w, 16), "int32")
+

 def test_conv2d_transpose_infer_type():
     # symbolic in batch dimension
     n, c, h, w = tvm.var("n"), 10, 10, 12
...