Commit ee79703c by Yizhi Liu Committed by Lianmin Zheng

[Lang] Layout in TVM node system (#2509)

* move layout.h & layout.cc from relay to tvm

* change ConvertLayout in relay to bijectiveLayout->Forward/backward

* add first test case

* add LayoutAxis

* add LayoutAxis struct and compiles

* simplify BijectiveLayout rule consturct

* polish func name for Layout, move impl to .cc, remove Layout::defined(), add defined() checker

* partially add layout py support

* add layout test cases

* add doc for tvm.layout & tvm.bijective_layout

* fix lint

* fix lint

* fix layout name generation bug

* fix layout typo

* address comments and add topi.layout_transform

* layout.h->data_layout.h, test_lang_layout.py->test_lang_data_layout.py
parent ddc31fd4
...@@ -68,6 +68,7 @@ List of operators ...@@ -68,6 +68,7 @@ List of operators
topi.greater_equal topi.greater_equal
topi.less_equal topi.less_equal
topi.arange topi.arange
topi.layout_transform
topi.image.resize topi.image.resize
...@@ -125,6 +126,7 @@ topi ...@@ -125,6 +126,7 @@ topi
.. autofunction:: topi.greater .. autofunction:: topi.greater
.. autofunction:: topi.less .. autofunction:: topi.less
.. autofunction:: topi.arange .. autofunction:: topi.arange
.. autofunction:: topi.layout_transform
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
...@@ -674,42 +674,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] ...@@ -674,42 +674,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& outputs) { const Array<Tensor>& outputs) {
const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed); const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
return Array<Tensor>{
Layout src_layout(param.src_layout); topi::layout_transform(inputs[0], param.src_layout, param.dst_layout)
Layout dst_layout(param.dst_layout);
if (src_layout == dst_layout) {
return Array<Tensor>{ inputs[0] };
} else if (!src_layout.defined() || !dst_layout.defined()) {
LOG(FATAL) << "cannot convert from/to undefined layout";
}
CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << param.src_layout
<< " to " << param.dst_layout;
return Array<Tensor> {
topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) {
std::vector<Expr> dst_to_src_indices;
for (Layout::LayoutDim src_axis : src_layout) {
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::is_superdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::is_subdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<Expr>(dst_to_src_indices);
})
}; };
}) })
.set_support_level(1); .set_support_level(1);
......
...@@ -515,7 +515,7 @@ def decl_buffer(shape, ...@@ -515,7 +515,7 @@ def decl_buffer(shape,
scope="", scope="",
data_alignment=-1, data_alignment=-1,
offset_factor=0): offset_factor=0):
"""Decleare a new symbolic buffer. """Declare a new symbolic buffer.
Normally buffer is created automatically during lower and build. Normally buffer is created automatically during lower and build.
This is only needed if user want to specify their own buffer layout. This is only needed if user want to specify their own buffer layout.
...@@ -587,6 +587,49 @@ def decl_buffer(shape, ...@@ -587,6 +587,49 @@ def decl_buffer(shape,
data, dtype, shape, strides, elem_offset, name, scope, data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor) data_alignment, offset_factor)
def layout(layout_str):
"""Create a layout node from a string.
Parameters
----------
layout_str : str
A layout representation is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).
Returns
-------
layout : Layout
The created layout
"""
return _api_internal._Layout(layout_str)
def bijective_layout(src_layout, dst_layout):
"""Create a bijective layout mapping.
Parameters
----------
src_layout : str or Layout
source layout.
dst_layout : str or Layout
destination layout.
Returns
-------
bijective_layout : BijectiveLayout
The created bijective layout
"""
if isinstance(src_layout, str):
src_layout = layout(src_layout)
if isinstance(dst_layout, str):
dst_layout = layout(dst_layout)
return _api_internal._BijectiveLayout(src_layout, dst_layout)
def _IterVar(dom, name, iter_type, thread_tag=''): def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar """Internal function to create IterVar
......
...@@ -185,3 +185,142 @@ class HybridOp(Operation): ...@@ -185,3 +185,142 @@ class HybridOp(Operation):
def axis(self): def axis(self):
"""Represent axis of IterVar, also defined when it is a HybridOp""" """Represent axis of IterVar, also defined when it is a HybridOp"""
return self.__getattr__("axis") return self.__getattr__("axis")
@register_node
class Layout(NodeBase):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
Do not construct directly, use :any:`layout` instead.
See the documentation of :any:`layout` for more details.
See Also
--------
layout : Declare a layout
"""
def __str__(self):
return self.name
def __repr__(self):
return "Layout(" + self.name + ")"
def __len__(self):
return _api_internal._LayoutNdim(self)
def __contains__(self, axis):
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Layout index out of range")
return _api_internal._LayoutGetItem(self, index)
def index_of(self, axis):
"""Get the index of an axis
Parameters
----------
axis : str
The axis name, need to be [a-z,A-Z]
Returns
-------
index : int
The index of the axis, -1 if not found.
"""
return _api_internal._LayoutIndexOf(self, axis)
def factor_of(self, axis):
"""Get the factor size of the subordinate axis.
Parameters
----------
axis : str
The axis name, need to be [a-z,A-Z]
Returns
-------
factor : int
the size of the subordinate-axis of axis (if axis is a primal-axis),
or the size of axis itself (if axis is a subordinate-axis).
Return -1 if axis is not in the layout.
"""
return _api_internal._LayoutFactorOf(self, axis)
@register_node
class BijectiveLayout(NodeBase):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
Do not construct directly, use :any:`bijective_layout` instead.
See the documentation of :any:`bijective_layout` for more details.
See Also
--------
bijective_layout : Declare a bijective layout converter
"""
def forward_index(self, index):
"""Given the indices of the src-layout, infer the dst index.
Parameters
----------
index: Array of Expr
The indices in src-layout.
Returns
-------
dst_index: Array of Expr
The inferred indices in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardIndex(self, index)
def backward_index(self, index):
"""Given the indices of the dst-layout, infer the src index.
Parameters
----------
index: Array of Expr
The indices in dst-layout.
Returns
-------
src_index: Array of Expr
The inferred indices in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardIndex(self, index)
def forward_shape(self, shape):
"""Given the shape of the src-layout, infer the dst shape.
Parameters
----------
shape: Array of Expr
The shape in src-layout.
Returns
-------
dst_shape: Array of Expr
The inferred shape in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardShape(self, shape)
def backward_shape(self, shape):
"""Given the shape of the dst-layout, infer the src shape.
Parameters
----------
shape: Array of Expr
The shape in dst-layout.
Returns
-------
src_shape: Array of Expr
The inferred shape in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardShape(self, shape)
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/data_layout.h>
namespace tvm { namespace tvm {
...@@ -224,6 +225,63 @@ TVM_REGISTER_API("_BufferVStore") ...@@ -224,6 +225,63 @@ TVM_REGISTER_API("_BufferVStore")
.vstore(args[1], args[2]); .vstore(args[1], args[2]);
}); });
TVM_REGISTER_API("_Layout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LayoutNode::make(args[0]);
});
TVM_REGISTER_API("_LayoutIndexOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.IndexOf(LayoutAxis::make(args[1]));
});
TVM_REGISTER_API("_LayoutFactorOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.FactorOf(LayoutAxis::make(args[1]));
});
TVM_REGISTER_API("_LayoutNdim")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(args[0].operator Layout().ndim());
});
TVM_REGISTER_API("_LayoutGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
const LayoutAxis& axis = args[0].operator Layout()[args[1]];
*ret = axis.name();
});
TVM_REGISTER_API("_BijectiveLayout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BijectiveLayoutNode::make(args[0], args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardShape(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardShape(args[1]);
});
TVM_REGISTER_API("_Tensor") TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0], *ret = TensorNode::make(args[0],
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
* \brief Property def of nn operators. * \brief Property def of nn operators.
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/debug.h> #include <tvm/relay/attrs/debug.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <vector> #include <vector>
#include "./type_relations.h" #include "./type_relations.h"
#include "./op_common.h" #include "./op_common.h"
#include "./layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
* \file resize.cc * \file resize.cc
* \brief Image operators * \brief Image operators
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/image/resize.h> #include <topi/image/resize.h>
#include "../layout.h"
#include "../op_common.h" #include "../op_common.h"
namespace tvm { namespace tvm {
...@@ -28,17 +28,18 @@ bool ResizeRel(const Array<Type>& types, ...@@ -28,17 +28,18 @@ bool ResizeRel(const Array<Type>& types,
const ResizeAttrs* param = attrs.as<ResizeAttrs>(); const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->layout); const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW)) auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "Resize only support input layouts that are convertible from NCHW." << "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW); auto oshape = layout_converter.ForwardShape(data->shape);
oshape[2] = param->size[0]; oshape.Set(2, param->size[0]);
oshape[3] = param->size[1]; oshape.Set(3, param->size[1]);
// assign output type // assign output type
reporter->Assign(types[1], reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout), TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype)); data->dtype));
return true; return true;
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/relay/op/layout.cc
* \brief Layout expression.
*/
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(LayoutNode);
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout) {
CHECK_EQ(src_layout.ndim(), src.size());
if (src_layout == dst_layout) {
return src;
} else if (!src_layout.defined()) {
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
} else if (!dst_layout.defined()) {
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
}
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from "
<< src_layout << " to " << dst_layout;
std::vector<IndexExpr> dst(dst_layout.ndim());
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_dim = src_layout[i];
if (Layout::IsSuperdim(src_dim)) {
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
int src_factor = src_layout.Subsizeof(src_dim);
int dst_factor = dst_layout.Subsizeof(src_dim);
IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) {
CHECK(is_const_int(src[src_minor_pos], src_factor))
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
src_dim_size *= src_factor;
}
dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
CHECK_LE(dst_factor, const_src_dim_size[0])
<< "Converting " << Array<IndexExpr>(src)
<< " from " << src_layout
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
}
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
}
}
}
return dst;
}
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout) {
std::vector<IndexExpr> ret(src.size());
for (size_t i = 0; i < src.size(); ++i) {
ret[i] = src[i];
}
return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief Property def of nn operators. * \brief Property def of nn operators.
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
...@@ -14,7 +15,6 @@ ...@@ -14,7 +15,6 @@
#include "../type_relations.h" #include "../type_relations.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
#include "../op_common.h" #include "../op_common.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
* \file pad.cc * \file pad.cc
* \brief Implementation of operator pad * \brief Implementation of operator pad
*/ */
#include <tvm/data_layout.h>
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn.h> #include <topi/nn.h>
#include <vector> #include <vector>
#include "../layout.h"
#include "../op_common.h" #include "../op_common.h"
namespace tvm { namespace tvm {
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
* \file pooling.cc * \file pooling.cc
* \brief Pooling operators * \brief Pooling operators
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h> #include <topi/nn/pooling.h>
#include <vector> #include <vector>
#include "../layout.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
namespace tvm { namespace tvm {
...@@ -32,14 +32,15 @@ Array<Array<Layout> > Pool2DInferCorrectLayout( ...@@ -32,14 +32,15 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
Layout raw_layout(params->layout); Layout raw_layout(params->layout);
Layout input = new_in_layouts[0]; Layout input = new_in_layouts[0];
if (input.Indexof('W') == raw_layout.Indexof('W') && if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.Indexof('H') == raw_layout.Indexof('H') && input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains('w') && !input.Contains('h')) { !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout params->layout = input.name(); // modify self to follow the input layout
} }
} }
return Array<Array<Layout> >{{params->layout}, {params->layout}}; Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
} }
template <typename AttrType> template <typename AttrType>
...@@ -59,13 +60,13 @@ bool Pool2DRel(const Array<Type>& types, ...@@ -59,13 +60,13 @@ bool Pool2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') && CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains('h') && !layout.Contains('w')) !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout << "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split"; << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H'); const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.Indexof('W'); const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
IndexExpr pad_h, pad_w; IndexExpr pad_h, pad_w;
if (param->padding.size() == 1) { if (param->padding.size() == 1) {
...@@ -125,6 +126,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -125,6 +126,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>(); const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr); CHECK(param != nullptr);
auto pool_size = param->pool_size; auto pool_size = param->pool_size;
...@@ -132,10 +134,13 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -132,10 +134,13 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
auto padding = param->padding; auto padding = param->padding;
auto ceil_mode = param->ceil_mode; auto ceil_mode = param->ceil_mode;
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW")))
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "max_pool2d currently only supports layouts that are convertible from NCHW"; << "max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1) << "max_pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
CHECK_EQ(layout.Indexof('w'), -1) << "max_pool2d does not support input split on width"; << "max_pool2d does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)" << "Pool2D only support 4-D input (e.g., NCHW)"
...@@ -271,13 +276,13 @@ bool GlobalPool2DRel(const Array<Type>& types, ...@@ -271,13 +276,13 @@ bool GlobalPool2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') && CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains('h') && !layout.Contains('w')) !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout << "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split"; << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H'); const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.Indexof('W'); const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
Array<IndexExpr> oshape(dshape); Array<IndexExpr> oshape(dshape);
oshape.Set(hidx, 1); oshape.Set(hidx, 1);
oshape.Set(widx, 1); oshape.Set(widx, 1);
...@@ -293,14 +298,15 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs, ...@@ -293,14 +298,15 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<GlobalPool2DAttrs>(); const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW"))) CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1) CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "global_avg_pool2d does not support input split on height"; << "global_avg_pool2d does not support input split on height";
CHECK_EQ(layout.Indexof('w'), -1) CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "global_avg_pool2d does not support input split on width"; << "global_avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file upsampling.cc * \file upsampling.cc
* \brief upsampling operator * \brief upsampling operator
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
...@@ -11,7 +12,6 @@ ...@@ -11,7 +12,6 @@
#include <topi/nn/upsampling.h> #include <topi/nn/upsampling.h>
#include <vector> #include <vector>
#include "../op_common.h" #include "../op_common.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -31,18 +31,20 @@ bool UpSamplingRel(const Array<Type>& types, ...@@ -31,18 +31,20 @@ bool UpSamplingRel(const Array<Type>& types,
const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>(); const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->layout); const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW))
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "UpSampling only support input layouts that are convertible from NCHW." << "UpSampling only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW); auto oshape = layout_converter.ForwardShape(data->shape);
oshape[2] = oshape[2] * param->scale; oshape.Set(2, oshape[2] * param->scale);
oshape[3] = oshape[3] * param->scale; oshape.Set(3, oshape[3] * param->scale);
// assign output type // assign output type
reporter->Assign(types[1], reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout), TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype)); data->dtype));
return true; return true;
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/data_layout.h>
#include <topi/transform.h> #include <topi/transform.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/broadcast.h> #include <topi/broadcast.h>
...@@ -16,7 +17,6 @@ ...@@ -16,7 +17,6 @@
#include "../op_common.h" #include "../op_common.h"
#include "../../../arithmetic/compute_expr.h" #include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -218,7 +218,7 @@ Array<Array<Layout>> ConcatenateLayout( ...@@ -218,7 +218,7 @@ Array<Array<Layout>> ConcatenateLayout(
Layout ret; Layout ret;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated. if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
Layout::LayoutDim concate_dim = old_in_layouts[0][axis]; const auto& concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) { for (size_t i = 0; i < new_in_layouts.size(); ++i) {
if (new_in_layouts[i].ndim() > axis && if (new_in_layouts[i].ndim() > axis &&
new_in_layouts[i][axis] == concate_dim) { new_in_layouts[i][axis] == concate_dim) {
...@@ -234,7 +234,7 @@ Array<Array<Layout>> ConcatenateLayout( ...@@ -234,7 +234,7 @@ Array<Array<Layout>> ConcatenateLayout(
} }
} }
if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) { if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}}; return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
} }
} }
...@@ -1682,46 +1682,10 @@ Array<Tensor> LayoutTransformCompute(const Attrs& attrs, ...@@ -1682,46 +1682,10 @@ Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
const LayoutTransformAttrs *param = attrs.as<LayoutTransformAttrs>(); const auto* param = attrs.as<LayoutTransformAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
return Array<Tensor>{
Layout src_layout(param->src_layout); topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
Layout dst_layout(param->dst_layout);
if (src_layout.Equals(dst_layout)) {
return Array<Tensor>{ inputs[0] };
}
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from " << param->src_layout << " to " << param->dst_layout;
const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout);
return Array<Tensor> {
topi::layout_transform(inputs[0], out_shape, [&](const Array<tvm::Var>& dst_indices) {
std::vector<tvm::Expr> dst_to_src_indices;
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_axis = src_layout[i];
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.Subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.Subsizeof(src_axis));
tvm::Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::IsSuperdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::IsSubdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<tvm::Expr>(dst_to_src_indices);
})
}; };
} }
...@@ -1738,10 +1702,12 @@ bool LayoutTransformRel(const Array<Type>& types, ...@@ -1738,10 +1702,12 @@ bool LayoutTransformRel(const Array<Type>& types,
CHECK(src_layout.defined() && dst_layout.defined()) CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout"; << "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout);
CHECK(layout_converter.defined())
<< "cannot convert from " << params->src_layout << " to " << params->dst_layout; << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout); const auto& out_shape = layout_converter.ForwardShape(data->shape);
reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype)); reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
return true; return true;
} }
......
...@@ -26,7 +26,7 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { ...@@ -26,7 +26,7 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
if (src_layout.Equals(dst_layout)) { return raw; } if (src_layout.Equals(dst_layout)) { return raw; }
CHECK(src_layout.defined() && dst_layout.defined()) CHECK(src_layout.defined() && dst_layout.defined())
<< "Cannot insert layout transform because there are undefined layouts"; << "Cannot insert layout transform because there are undefined layouts";
CHECK(src_layout.Convertible(dst_layout)) CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
<< "Cannot insert layout transform because there are inconvertible layouts: " << "Cannot insert layout transform because there are inconvertible layouts: "
<< src_layout << " v.s. " << dst_layout; << src_layout << " v.s. " << dst_layout;
static auto &transform_op = Op::Get("layout_transform"); static auto &transform_op = Op::Get("layout_transform");
......
...@@ -9,10 +9,9 @@ ...@@ -9,10 +9,9 @@
#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ #ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ #define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#include <tvm/data_layout.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include "../op/layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -78,7 +77,7 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs, ...@@ -78,7 +77,7 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx, layouts.Set(undef_idx,
layouts[defined_idx].Sublayout( layouts[defined_idx].SubLayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size())); old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}}; return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
...@@ -90,21 +89,22 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs, ...@@ -90,21 +89,22 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
} }
} else { } else {
// try to broadcast the tensors to the larger dimension // try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1; int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
int small_idx = 1 - large_idx; int small_idx = 1 - large_idx;
Layout ret = layouts[large_idx]; Layout ret = layouts[large_idx];
// extract common part // extract common part
size_t i = layouts[large_idx].ndim(); size_t i = layouts[large_idx].ndim();
for (; i != 0; --i) { for (; i != 0; --i) {
auto dim = layouts[large_idx][i-1]; const auto& axis = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) { if (!layouts[small_idx].Contains(axis.ToPrimal())) {
break; break;
} }
} }
Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i); Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
if (!layouts[small_idx].Convertible(common_part)) { // fail if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
// not convertible
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}}; return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
} }
......
...@@ -91,8 +91,10 @@ class BranchGroupFinder : private ExprVisitor { ...@@ -91,8 +91,10 @@ class BranchGroupFinder : private ExprVisitor {
CHECK(attrs_b); CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>(); const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>(); const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW); const auto shape_a = BijectiveLayoutNode::make(
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW); Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
const auto shape_b = BijectiveLayoutNode::make(
Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
......
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
* \brief Fold axis scaling into weights of * \brief Fold axis scaling into weights of
* conv/dense operators. * conv/dense operators.
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "pattern_util.h" #include "pattern_util.h"
#include "pass_util.h" #include "pass_util.h"
#include "../op/layout.h"
namespace tvm { namespace tvm {
...@@ -435,8 +435,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) { ...@@ -435,8 +435,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C'); int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = data_layout.Indexof('c'); int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>(); Message none = NullValue<Message>();
...@@ -449,7 +449,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) { ...@@ -449,7 +449,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('i') < 0 && if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis}; data_axes = {c_big_axis};
...@@ -473,15 +473,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call, ...@@ -473,15 +473,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C'); int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(sdata->axes.size() == 1 && CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value); c_big_axis == sdata->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O'); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int big_ic_axis = kernel_layout.Indexof('I'); int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
...@@ -857,8 +857,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) ...@@ -857,8 +857,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
CHECK(param != nullptr); CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C'); int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = out_layout.Indexof('c'); int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
...@@ -869,8 +869,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) ...@@ -869,8 +869,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('o') < 0 && if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
kernel_layout.Indexof('i') < 0 && kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
return MessageNode::make({c_big_axis}, false); return MessageNode::make({c_big_axis}, false);
...@@ -891,16 +891,16 @@ Expr Conv2DBackwardTransform(const Call& call, ...@@ -891,16 +891,16 @@ Expr Conv2DBackwardTransform(const Call& call,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C'); int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('o'), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(message->axes.size() == 1 && CHECK(message->axes.size() == 1 &&
c_big_axis == message->axes[0]->value); c_big_axis == message->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O'); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d); CHECK(param->groups == 1 || is_depthwise_conv2d);
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "../op/layout.h" #include <tvm/data_layout.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -51,8 +51,8 @@ int64_t ConvMacCount(const Call& call_node) { ...@@ -51,8 +51,8 @@ int64_t ConvMacCount(const Call& call_node) {
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>(); const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape; Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_attr->data_layout; std::string data_layout = conv_2d_attr->data_layout;
int32_t C_ind = Layout(data_layout).Indexof('C'); int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).Indexof('c'); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK(C_ind != -1) CHECK(C_ind != -1)
<< "There is no input channel dimension."; << "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value); int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
......
...@@ -8,13 +8,13 @@ ...@@ -8,13 +8,13 @@
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_ #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_ #define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <string> #include <string>
#include "../op/layout.h"
namespace tvm { namespace tvm {
...@@ -155,9 +155,8 @@ inline bool IsDepthwiseConv2D(const Call& call, ...@@ -155,9 +155,8 @@ inline bool IsDepthwiseConv2D(const Call& call,
const Conv2DAttrs* param, const Conv2DAttrs* param,
const Layout& kernel_layout) { const Layout& kernel_layout) {
static const Layout kOIHW("OIHW"); static const Layout kOIHW("OIHW");
auto wshape = ConvertLayout( const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
call->args[1]->type_as<TensorTypeNode>()->shape, auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
kernel_layout, kOIHW);
return is_const_int(wshape[0], param->groups) && return is_const_int(wshape[0], param->groups) &&
is_const_int(wshape[1], 1); is_const_int(wshape[1], 1);
} }
......
"""Test layout and bijective-layout node"""
import tvm
from topi.util import get_const_tuple
def test_layout():
layout = tvm.layout("NCHW16c")
assert layout is not None
assert isinstance(layout, tvm.tensor.Layout)
assert layout.factor_of("c") == 16
assert layout.factor_of("C") == 16
assert layout.factor_of("N") == -1
assert layout.index_of("N") == 0
assert layout.index_of("C") == 1
assert layout.index_of("H") == 2
assert layout.index_of("W") == 3
assert layout.index_of("c") == 4
assert layout.index_of("O") == -1
assert "N" in layout
assert "C" in layout
assert "H" in layout
assert "W" in layout
assert "c" in layout
assert "O" not in layout
assert layout[0] == "N"
assert layout[1] == "C"
assert layout[2] == "H"
assert layout[3] == "W"
assert layout[4] == "c"
assert layout[-1] == "c"
def test_bilayout_convertible():
# not convertible
assert tvm.bijective_layout("NCHW", "ABCD") is None
# convertible
assert tvm.bijective_layout("NCHW", "NCHW16c") is not None
def test_bilayout_shape():
bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
assert isinstance(bilayout, tvm.tensor.BijectiveLayout)
dst_shape = bilayout.forward_shape((1, 32, 7, 7))
assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
src_shape = bilayout.backward_shape(dst_shape)
assert get_const_tuple(src_shape) == (1, 32, 7, 7)
def test_bilayout_index():
bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
dst_index = bilayout.forward_index([0, 18, 6, 6])
assert get_const_tuple(dst_index) == (0, 1, 6, 6, 2)
src_index = bilayout.backward_index([0, 1, 6, 6, 2])
assert get_const_tuple(src_index) == (0, 18, 6, 6)
if __name__ == "__main__":
test_layout()
test_bilayout_convertible()
test_bilayout_shape()
test_bilayout_index()
...@@ -450,28 +450,5 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I, ...@@ -450,28 +450,5 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
return tvm::compute(output_shape, l, name, tag); return tvm::compute(output_shape, l, name, tag);
} }
using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;
/*!
* \brief Transform the layout according to the mapping function \p to_src_indices.
* \param src the source input.
* \param dst_shape the output shape.
* \param to_src_indices the mapping function from input index to output index.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape \p dst_shape.
*/
inline Tensor layout_transform(const Tensor& src,
const Array<Expr>& dst_shape,
const FLayoutIndicesTransform& to_src_indices,
const std::string name = "layout_transform",
const std::string tag = kInjective) {
auto src_shape = src->shape;
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
return src(to_src_indices(dst_indices));
}, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_NN_H_ #endif // TOPI_NN_H_
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "topi/detail/ravel_unravel.h" #include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h" #include "topi/detail/constant_utils.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
#include "tvm/data_layout.h"
namespace topi { namespace topi {
using namespace tvm; using namespace tvm;
...@@ -882,5 +883,43 @@ inline Tensor arange(const Expr start, ...@@ -882,5 +883,43 @@ inline Tensor arange(const Expr start,
}, name, tag); }, name, tag);
} }
/*!
* \brief Transform the layout according to \p src_layout and \p dst_layout
* \param src the source input.
* \param src_layout the source layout.
* \param dst_layout the destination layout.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape in \p dst_layout
*/
inline Tensor layout_transform(const Tensor& src,
const std::string& src_layout,
const std::string& dst_layout,
const std::string name = "layout_transform",
const std::string tag = kInjective) {
Layout src_layout_struct = LayoutNode::make(src_layout);
Layout dst_layout_struct = LayoutNode::make(dst_layout);
if (src_layout_struct.Equals(dst_layout_struct)) {
return src;
}
CHECK(src_layout_struct.defined() && dst_layout_struct.defined())
<< "cannot convert from/to undefined layout";
auto layout_converter = BijectiveLayoutNode::make(src_layout_struct, dst_layout_struct);
CHECK(layout_converter.defined())
<< "cannot convert from " << src_layout << " to " << dst_layout;
Array<Expr> dst_shape = layout_converter.ForwardShape(src->shape);
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
Array<Expr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
Array<Expr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
return src(src_indices);
}, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_TRANSFORM_H_ #endif // TOPI_TRANSFORM_H_
...@@ -318,3 +318,20 @@ def arange(start, stop=None, step=1, dtype="float32"): ...@@ -318,3 +318,20 @@ def arange(start, stop=None, step=1, dtype="float32"):
stop = start stop = start
start = 0 start = 0
return cpp.arange(start, stop, step, dtype) return cpp.arange(start, stop, step, dtype)
def layout_transform(array, src_layout, dst_layout):
"""Transform the layout according to src_layout and dst_layout
Parameters
----------
array : tvm.Tensor
The source array.
src_layout : str
the source layout.
dst_layout : str
the destination layout.
"""
return cpp.layout_transform(array, src_layout, dst_layout)
...@@ -272,6 +272,11 @@ TVM_REGISTER_GLOBAL("topi.split") ...@@ -272,6 +272,11 @@ TVM_REGISTER_GLOBAL("topi.split")
} }
}); });
TVM_REGISTER_GLOBAL("topi.layout_transform")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = layout_transform(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.take") TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) { if (args.size() == 2) {
......
...@@ -449,6 +449,34 @@ def test_arange(): ...@@ -449,6 +449,34 @@ def test_arange():
verify_arange(20, 1, -1.5) verify_arange(20, 1, -1.5)
def test_layout_transform():
in_shape = (1, 32, 8, 8)
A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
B = topi.layout_transform(A, "NCHW", "NCHW16c")
input = np.random.uniform(size=in_shape).astype(A.dtype)
output = np.transpose(input, axes=(0, 2, 3, 1))
output = np.reshape(output, newshape=(1, 8, 8, 2, 16))
output = np.transpose(output, axes=(0, 3, 1, 2, 4))
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
tvm_input = tvm.nd.array(input, ctx)
tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=B.dtype)
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
f = tvm.build(s, [A, B], device, name="layout_transform")
f(tvm_input, tvm_output)
tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
for backend in get_all_backend():
check_device(backend)
if __name__ == "__main__": if __name__ == "__main__":
test_strided_slice() test_strided_slice()
test_concatenate() test_concatenate()
...@@ -462,3 +490,4 @@ if __name__ == "__main__": ...@@ -462,3 +490,4 @@ if __name__ == "__main__":
test_take() test_take()
test_gather_nd() test_gather_nd()
test_arange() test_arange()
test_layout_transform()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment