Commit 2c231b5a by Lianmin Zheng, committed by Tianqi Chen

[RELAY] Move Layout to tvm Node system (#2125)

parent 2c7d2d78
@@ -85,7 +85,7 @@ class Var : public HalideIR::VarExpr {
 /*!
- * \brief Container of constant ineteger (IntImm).
+ * \brief Container of constant integer (IntImm).
  *
  * This is used to store and automate type check
  * attributes that must be constant integer.
...
@@ -5,7 +5,7 @@
  */
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/image.h>
-#include "../nn/layout.h"
+#include "../layout.h"
 namespace tvm {
 namespace relay {
@@ -25,7 +25,7 @@ bool ResizeRel(const Array<Type>& types,
   const ResizeAttrs* param = attrs.as<ResizeAttrs>();
   CHECK(param != nullptr);
   const Layout in_layout(param->layout);
-  CHECK(in_layout.convertible(kNCHW))
+  CHECK(in_layout.Convertible(kNCHW))
     << "Resize only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
...
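Reviewer note (illustration, not part of the commit): the move into the tvm Node system renames the Layout helpers from lower-case (convertible, indexof, contains) to CamelCase (Convertible, Indexof, Contains). A minimal sketch of what the Convertible check accepts, assuming the string constructor used in the hunks above; to my understanding two layouts are convertible when they are built from the same set of upper-case super-dimensions, regardless of ordering or splitting:

// Sketch only; assumes the renamed Layout API introduced by this commit.
Layout nchw("NCHW");
Layout nhwc("NHWC");      // same super-dims {N, C, H, W}, different order
Layout nchw8c("NCHW8c");  // C additionally split into inner blocks of 8
CHECK(nchw.Convertible(nhwc));    // permutations are convertible
CHECK(nchw.Convertible(nchw8c));  // splits are convertible too
CHECK(!nchw.Convertible(Layout("OIHW")));  // different super-dims are not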
/*!
 * Copyright (c) 2018 by Contributors
 * \file src/relay/op/layout.cc
 * \brief Layout expression.
 */
#include "layout.h"

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(LayoutNode);

std::vector<IndexExpr> ConvertLayout(
    std::vector<IndexExpr> src,
    const Layout& src_layout,
    const Layout& dst_layout) {
  CHECK_EQ(src_layout.ndim(), src.size());
  if (src_layout == dst_layout) {
    return src;
  } else if (!src_layout.defined()) {
    LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
  } else if (!dst_layout.defined()) {
    LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
  }

  CHECK(src_layout.Convertible(dst_layout))
    << "cannot convert from "
    << src_layout << " to " << dst_layout;

  std::vector<IndexExpr> dst(dst_layout.ndim());
  for (size_t i = 0; i < src_layout.ndim(); ++i) {
    Layout::LayoutDim src_dim = src_layout[i];
    if (Layout::IsSuperdim(src_dim)) {
      int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
      int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
      int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
      int src_factor = src_layout.Subsizeof(src_dim);
      int dst_factor = dst_layout.Subsizeof(src_dim);

      IndexExpr src_dim_size = src[i];
      if (src_minor_pos >= 0) {
        CHECK(is_const_int(src[src_minor_pos], src_factor))
          << "src shape " << Array<IndexExpr>(src)
          << " does not agree with layout "
          << src_layout;
        src_dim_size *= src_factor;
      }

      dst[dst_major_pos] = src_dim_size;
      if (dst_minor_pos >= 0) {
        CHECK_GT(dst_factor, 0);
        if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
          CHECK_LE(dst_factor, const_src_dim_size[0])
            << "Converting " << Array<IndexExpr>(src)
            << " from " << src_layout
            << " to " << dst_layout
            << ": cannot split dimension size of "
            << src_dim_size << " by " << dst_factor;
        }
        dst[dst_major_pos] /= dst_factor;
        dst[dst_minor_pos] = dst_factor;
      }
    }
  }
  return dst;
}

std::vector<IndexExpr> ConvertLayout(
    const Array<IndexExpr>& src,
    const Layout& src_layout,
    const Layout& dst_layout) {
  std::vector<IndexExpr> ret(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    ret[i] = src[i];
  }
  return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
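
Reviewer note (illustration, not part of the commit): ConvertLayout maps a shape expressed in src_layout to the equivalent shape in dst_layout, splitting a super-dimension by the sub-dimension factor wherever the destination layout requires it. A worked sketch, with values chosen only for illustration:

// Sketch only; assumes the API defined in layout.cc above.
Layout src("NCHW");
Layout dst("NCHW8c");
std::vector<IndexExpr> shape = {1, 32, 224, 224};  // N, C, H, W
// 'C' occupies major position 1 and minor position 4 in NCHW8c with
// factor 8, so C = 32 is split: the result is {1, 32/8, 224, 224, 8}.
std::vector<IndexExpr> out = ConvertLayout(shape, src, dst);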
@@ -6,7 +6,8 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <vector>
-#include "layout.h"
+#include "../layout.h"
 namespace tvm {
 namespace relay {
@@ -28,16 +29,16 @@ bool Conv2DRel(const Array<Type>& types,
   CHECK(param != nullptr);
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->weight_layout);
-  CHECK(in_layout.convertible(kNCHW))
+  CHECK(in_layout.Convertible(kNCHW))
     << "Conv only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
-  CHECK(kernel_layout.convertible(kOIHW))
+  CHECK(kernel_layout.Convertible(kOIHW))
     << "Conv only support kernel layouts that are convertible from OIHW."
     << " But got "<< kernel_layout;
   Layout out_layout(param->out_layout);
   if (!out_layout.defined()) out_layout = in_layout;
-  CHECK(out_layout.convertible(kNCHW))
+  CHECK(out_layout.Convertible(kNCHW))
     << "Conv only support output layouts that are convertible from NCHW."
     << " But got " << out_layout;
@@ -55,7 +56,7 @@ bool Conv2DRel(const Array<Type>& types,
                              param->kernel_size[0],
                              param->kernel_size[1]});
     wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
-    wshape[kernel_layout.indexof('O')] *= param->groups;
+    wshape[kernel_layout.Indexof('O')] *= param->groups;
     channels = param->channels;
     dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
@@ -177,10 +178,10 @@ bool Conv2DTransposeRel(const Array<Type>& types,
   CHECK(param != nullptr);
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->weight_layout);
-  CHECK(in_layout.convertible(kNCHW))
+  CHECK(in_layout.Convertible(kNCHW))
     << "Conv only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
-  CHECK(kernel_layout.convertible(kOIHW))
+  CHECK(kernel_layout.Convertible(kOIHW))
     << "Conv only support kernel layouts that are convertible from OIHW."
     << " But got "<< kernel_layout;
...
@@ -13,7 +13,7 @@
 #include <vector>
 #include "../type_relations.h"
 #include "../op_common.h"
-#include "layout.h"
+#include "../layout.h"
 namespace tvm {
 namespace relay {
...
@@ -7,7 +7,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <vector>
-#include "layout.h"
+#include "../layout.h"
 namespace tvm {
 namespace relay {
...
@@ -8,7 +8,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <topi/nn/pooling.h>
 #include <vector>
-#include "layout.h"
+#include "../layout.h"
 namespace tvm {
 namespace relay {
@@ -33,13 +33,13 @@ bool Pool2DRel(const Array<Type>& types,
   CHECK(param != nullptr);
   Layout layout(param->layout);
-  CHECK(layout.contains('H') && layout.contains('W') &&
-        !layout.contains('h') && !layout.contains('w'))
+  CHECK(layout.Contains('H') && layout.Contains('W') &&
+        !layout.Contains('h') && !layout.Contains('w'))
     << "Invalid layout " << layout
     << ". Pool2D layout must have H and W, which cannot be split";
-  const auto hidx = layout.indexof('H');
-  const auto widx = layout.indexof('W');
+  const auto hidx = layout.Indexof('H');
+  const auto widx = layout.Indexof('W');
   IndexExpr pad_h, pad_w;
   if (param->padding.size() == 1) {
@@ -102,10 +102,10 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
   auto padding = param->padding;
   auto ceil_mode = param->ceil_mode;
   Layout layout(param->layout);
-  CHECK(layout.convertible(Layout("NCHW")))
+  CHECK(layout.Convertible(Layout("NCHW")))
     << "max_pool2d currently only supports layouts that are convertible from NCHW";
-  CHECK_EQ(layout.indexof('h'), -1) << "max_pool2d does not support input split on height";
-  CHECK_EQ(layout.indexof('w'), -1) << "max_pool2d does not support input split on width";
+  CHECK_EQ(layout.Indexof('h'), -1) << "max_pool2d does not support input split on height";
+  CHECK_EQ(layout.Indexof('w'), -1) << "max_pool2d does not support input split on width";
   CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
     << "Pool2D only support 4-D input (e.g., NCHW)"
@@ -240,13 +240,13 @@ bool GlobalPool2DRel(const Array<Type>& types,
   CHECK(param != nullptr);
   Layout layout(param->layout);
-  CHECK(layout.contains('H') && layout.contains('W') &&
-        !layout.contains('h') && !layout.contains('w'))
+  CHECK(layout.Contains('H') && layout.Contains('W') &&
+        !layout.Contains('h') && !layout.Contains('w'))
     << "Invalid layout " << layout
     << ". Pool2D layout must have H and W, which cannot be split";
-  const auto hidx = layout.indexof('H');
-  const auto widx = layout.indexof('W');
+  const auto hidx = layout.Indexof('H');
+  const auto widx = layout.Indexof('W');
   std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
   oshape[hidx] = oshape[widx] = 1;
@@ -264,11 +264,11 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
   const auto* param = attrs.as<GlobalPool2DAttrs>();
   CHECK(param != nullptr);
   Layout layout(param->layout);
-  CHECK(layout.convertible(Layout("NCHW")))
+  CHECK(layout.Convertible(Layout("NCHW")))
     << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
-  CHECK_EQ(layout.indexof('h'), -1)
+  CHECK_EQ(layout.Indexof('h'), -1)
     << "global_avg_pool2d does not support input split on height";
-  CHECK_EQ(layout.indexof('w'), -1)
+  CHECK_EQ(layout.Indexof('w'), -1)
     << "global_avg_pool2d does not support input split on width";
   CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
...
@@ -5,7 +5,7 @@
  */
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
-#include "layout.h"
+#include "../layout.h"
 namespace tvm {
 namespace relay {
@@ -25,7 +25,7 @@ bool UpSamplingRel(const Array<Type>& types,
   const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
   CHECK(param != nullptr);
   const Layout in_layout(param->layout);
-  CHECK(in_layout.convertible(kNCHW))
+  CHECK(in_layout.Convertible(kNCHW))
     << "UpSampling only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
...
@@ -11,7 +11,8 @@
 #include <tvm/relay/expr_functor.h>
 #include "pattern_util.h"
 #include "pass_util.h"
-#include "../op/nn/layout.h"
+#include "../op/layout.h"
 namespace tvm {
 namespace relay {
@@ -378,8 +379,8 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
   CHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout weight_layout(param->weight_layout);
-  int c_big_axis = data_layout.indexof('C');
-  int c_small_axis = data_layout.indexof('c');
+  int c_big_axis = data_layout.Indexof('C');
+  int c_small_axis = data_layout.Indexof('c');
   CHECK_GE(c_big_axis, 0);
   AxesSet data_axes = NullValue<AxesSet>();
@@ -391,7 +392,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
-  if (weight_layout.indexof('i') < 0 &&
+  if (weight_layout.Indexof('i') < 0 &&
       c_small_axis < 0 &&
       (param->groups == 1 || is_depthwise_conv2d)) {
     data_axes = {c_big_axis};
@@ -412,15 +413,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
   CHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout weight_layout(param->weight_layout);
-  int c_big_axis = data_layout.indexof('C');
+  int c_big_axis = data_layout.Indexof('C');
   CHECK_GE(c_big_axis, 0);
   // For now, we only support simple pattern (no folded weight/data)
   // TODO(tvm-team) support general data layout
-  CHECK_EQ(weight_layout.indexof('i'), -1);
+  CHECK_EQ(weight_layout.Indexof('i'), -1);
   CHECK(sdata->axes.size() == 1 &&
         c_big_axis == sdata->axes[0]->value);
-  int big_oc_axis = weight_layout.indexof('O');
-  int big_ic_axis = weight_layout.indexof('I');
+  int big_oc_axis = weight_layout.Indexof('O');
+  int big_ic_axis = weight_layout.Indexof('I');
   // Check it must be depthwise or full conv2d.
   bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout);
@@ -779,8 +780,8 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
     out_layout = Layout(param->data_layout);
   }
   Layout weight_layout(param->weight_layout);
-  int c_big_axis = out_layout.indexof('C');
-  int c_small_axis = out_layout.indexof('c');
+  int c_big_axis = out_layout.Indexof('C');
+  int c_small_axis = out_layout.Indexof('c');
   CHECK_GE(c_big_axis, 0);
   // For now, we only support simple pattern (no folded weight/data)
@@ -791,8 +792,8 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
-  if (weight_layout.indexof('o') < 0 &&
-      weight_layout.indexof('i') < 0 &&
+  if (weight_layout.Indexof('o') < 0 &&
+      weight_layout.Indexof('i') < 0 &&
       c_small_axis < 0 &&
       (param->groups == 1 || is_depthwise_conv2d)) {
     return {c_big_axis};
@@ -816,16 +817,16 @@ Expr Conv2DBackwardTransform(const Call& call,
     out_layout = Layout(param->data_layout);
   }
   Layout weight_layout(param->weight_layout);
-  int c_big_axis = out_layout.indexof('C');
+  int c_big_axis = out_layout.Indexof('C');
   CHECK_GE(c_big_axis, 0);
   // For now, we only support simple pattern (no folded weight/data)
   // TODO(tvm-team) support general data layout
-  CHECK_EQ(weight_layout.indexof('o'), -1);
-  CHECK_EQ(weight_layout.indexof('i'), -1);
+  CHECK_EQ(weight_layout.Indexof('o'), -1);
+  CHECK_EQ(weight_layout.Indexof('i'), -1);
   CHECK(axes.size() == 1 &&
         c_big_axis == axes[0]->value);
-  int big_oc_axis = weight_layout.indexof('O');
+  int big_oc_axis = weight_layout.Indexof('O');
   // Check it must be depthwise or full conv2d.
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
   CHECK(param->groups == 1 || is_depthwise_conv2d);
...
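Reviewer note (illustration, not part of the commit): the fold-scale-axis checks above rely on Indexof returning -1 for an axis that is absent from the layout, so Indexof('i') == -1 asserts that the weight layout carries no split (folded) input-channel block. A minimal sketch, assuming the renamed API:

// Sketch only; assumes the renamed Layout API introduced by this commit.
Layout weight_layout("OIHW");
int big_oc_axis = weight_layout.Indexof('O');    // 0: output-channel super-dim
int small_ic_axis = weight_layout.Indexof('i');  // -1: no input-channel sub-dim
if (small_ic_axis < 0) {
  // weights are unfolded along input channels, so scaling the 'I' axis is safe
}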
@@ -11,7 +11,8 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/attrs/transform.h>
-#include "../op/nn/layout.h"
+#include "../op/layout.h"
 namespace tvm {
 namespace relay {
...