Commit ee79703c by Yizhi Liu, committed by Lianmin Zheng

[Lang] Layout in TVM node system (#2509)

* move layout.h & layout.cc from relay to tvm

* change ConvertLayout in relay to BijectiveLayout->Forward/Backward

* add first test case

* add LayoutAxis

* add LayoutAxis struct and compiles

* simplify BijectiveLayout rule construction

* polish func name for Layout, move impl to .cc, remove Layout::defined(), add defined() checker

* partially add layout py support

* add layout test cases

* add doc for tvm.layout & tvm.bijective_layout

* fix lint

* fix lint

* fix layout name generation bug

* fix layout typo

* address comments and add topi.layout_transform

* layout.h->data_layout.h, test_lang_layout.py->test_lang_data_layout.py
parent ddc31fd4
......@@ -68,6 +68,7 @@ List of operators
topi.greater_equal
topi.less_equal
topi.arange
topi.layout_transform
topi.image.resize
......@@ -125,6 +126,7 @@ topi
.. autofunction:: topi.greater
.. autofunction:: topi.less
.. autofunction:: topi.arange
.. autofunction:: topi.layout_transform
topi.nn
~~~~~~~
......
......@@ -674,42 +674,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
const Array<Tensor>& inputs,
const Array<Tensor>& outputs) {
const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
Layout src_layout(param.src_layout);
Layout dst_layout(param.dst_layout);
if (src_layout == dst_layout) {
return Array<Tensor>{ inputs[0] };
} else if (!src_layout.defined() || !dst_layout.defined()) {
LOG(FATAL) << "cannot convert from/to undefined layout";
}
CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << param.src_layout
<< " to " << param.dst_layout;
return Array<Tensor> {
topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) {
std::vector<Expr> dst_to_src_indices;
for (Layout::LayoutDim src_axis : src_layout) {
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::is_superdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::is_subdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<Expr>(dst_to_src_indices);
})
return Array<Tensor>{
topi::layout_transform(inputs[0], param.src_layout, param.dst_layout)
};
})
.set_support_level(1);
......
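For reference, the repacking that layout_transform performs for the NCHW to NCHW16c case (the indexing rule quoted in the hunk header above) can be reproduced with a short NumPy sketch. This is a hypothetical, standalone illustration, not part of the patch; the input shape and the blocking factor of 16 are assumptions matching the NCHW16c example.

```python
import numpy as np

# NCHW input; assume the channel dimension is split by a factor of 16.
data = np.random.rand(1, 32, 8, 8).astype("float32")
n, c, h, w = data.shape
factor = 16

# NCHW -> NCHW16c: out[n, co, h, w, ci] == data[n, co * factor + ci, h, w]
out = data.reshape(n, c // factor, factor, h, w).transpose(0, 1, 3, 4, 2)
assert out.shape == (1, 2, 8, 8, 16)
```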
......@@ -515,7 +515,7 @@ def decl_buffer(shape,
scope="",
data_alignment=-1,
offset_factor=0):
"""Decleare a new symbolic buffer.
"""Declare a new symbolic buffer.
Normally buffer is created automatically during lower and build.
This is only needed if user want to specify their own buffer layout.
......@@ -587,6 +587,49 @@ def decl_buffer(shape,
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)
def layout(layout_str):
"""Create a layout node from a string.
Parameters
----------
layout_str : str
A layout representation is composed of upper-case letters, lower-case letters, and numbers,
where an upper-case letter denotes a primal axis and
the corresponding lower-case letter with its factor size denotes the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here the subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).
Returns
-------
layout : Layout
The created layout
"""
return _api_internal._Layout(layout_str)
def bijective_layout(src_layout, dst_layout):
"""Create a bijective layout mapping.
Parameters
----------
src_layout : str or Layout
The source layout.
dst_layout : str or Layout
The destination layout.
Returns
-------
bijective_layout : BijectiveLayout
The created bijective layout
"""
if isinstance(src_layout, str):
src_layout = layout(src_layout)
if isinstance(dst_layout, str):
dst_layout = layout(dst_layout)
return _api_internal._BijectiveLayout(src_layout, dst_layout)
def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar
......
......@@ -185,3 +185,142 @@ class HybridOp(Operation):
def axis(self):
"""Represent axis of IterVar, also defined when it is a HybridOp"""
return self.__getattr__("axis")
@register_node
class Layout(NodeBase):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
Do not construct directly, use :any:`layout` instead.
See the documentation of :any:`layout` for more details.
See Also
--------
layout : Declare a layout
"""
def __str__(self):
return self.name
def __repr__(self):
return "Layout(" + self.name + ")"
def __len__(self):
return _api_internal._LayoutNdim(self)
def __contains__(self, axis):
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Layout index out of range")
return _api_internal._LayoutGetItem(self, index)
def index_of(self, axis):
"""Get the index of an axis
Parameters
----------
axis : str
The axis name; must be one of [a-z,A-Z].
Returns
-------
index : int
The index of the axis, -1 if not found.
"""
return _api_internal._LayoutIndexOf(self, axis)
def factor_of(self, axis):
"""Get the factor size of the subordinate axis.
Parameters
----------
axis : str
The axis name; must be one of [a-z,A-Z].
Returns
-------
factor : int
The size of the subordinate axis of axis (if axis is a primal axis),
or the size of axis itself (if axis is a subordinate axis).
Returns -1 if axis is not in the layout.
"""
return _api_internal._LayoutFactorOf(self, axis)
@register_node
class BijectiveLayout(NodeBase):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between the two layouts.
Do not construct directly, use :any:`bijective_layout` instead.
See the documentation of :any:`bijective_layout` for more details.
See Also
--------
bijective_layout : Declare a bijective layout converter
"""
def forward_index(self, index):
"""Given the indices of the src-layout, infer the dst index.
Parameters
----------
index: Array of Expr
The indices in src-layout.
Returns
-------
dst_index: Array of Expr
The inferred indices in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardIndex(self, index)
def backward_index(self, index):
"""Given the indices of the dst-layout, infer the src index.
Parameters
----------
index: Array of Expr
The indices in dst-layout.
Returns
-------
src_index: Array of Expr
The inferred indices in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardIndex(self, index)
def forward_shape(self, shape):
"""Given the shape of the src-layout, infer the dst shape.
Parameters
----------
shape: Array of Expr
The shape in src-layout.
Returns
-------
dst_shape: Array of Expr
The inferred shape in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardShape(self, shape)
def backward_shape(self, shape):
"""Given the shape of the dst-layout, infer the src shape.
Parameters
----------
shape: Array of Expr
The shape in dst-layout.
Returns
-------
src_shape: Array of Expr
The inferred shape in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardShape(self, shape)
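A quick sketch of how the new node classes are meant to be used from Python, assuming the tvm.layout and tvm.bijective_layout constructors added above. The symbolic-index form at the end is an assumption based on forward_index accepting arbitrary index expressions, the same way it accepts the plain integer indices used in the tests later in this diff.

```python
import tvm

# Blocked layout: N, C, H, W plus a 16-wide channel block.
layout = tvm.layout("NCHW16c")
assert layout.factor_of("c") == 16   # factor size of the subordinate axis
assert layout.index_of("H") == 2     # position of the primal axis H

# Bijective mapping between the plain and the blocked layout.
cvt = tvm.bijective_layout("NCHW", "NCHW16c")
print(cvt.forward_shape((1, 32, 7, 7)))   # -> [1, 2, 7, 7, 16]

# forward_index on symbolic indices should yield expressions
# roughly of the form [n, c/16, h, w, c % 16].
n, c, h, w = [tvm.var(x) for x in "nchw"]
print([str(e) for e in cvt.forward_index([n, c, h, w])])
```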
......@@ -11,6 +11,7 @@
#include <tvm/schedule.h>
#include <tvm/api_registry.h>
#include <tvm/build_module.h>
#include <tvm/data_layout.h>
namespace tvm {
......@@ -224,6 +225,63 @@ TVM_REGISTER_API("_BufferVStore")
.vstore(args[1], args[2]);
});
TVM_REGISTER_API("_Layout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LayoutNode::make(args[0]);
});
TVM_REGISTER_API("_LayoutIndexOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.IndexOf(LayoutAxis::make(args[1]));
});
TVM_REGISTER_API("_LayoutFactorOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.FactorOf(LayoutAxis::make(args[1]));
});
TVM_REGISTER_API("_LayoutNdim")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(args[0].operator Layout().ndim());
});
TVM_REGISTER_API("_LayoutGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
const LayoutAxis& axis = args[0].operator Layout()[args[1]];
*ret = axis.name();
});
TVM_REGISTER_API("_BijectiveLayout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BijectiveLayoutNode::make(args[0], args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardShape(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardShape(args[1]);
});
TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
......
......@@ -4,13 +4,13 @@
* \brief Property def of nn operators.
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/debug.h>
#include <topi/elemwise.h>
#include <vector>
#include "./type_relations.h"
#include "./op_common.h"
#include "./layout.h"
namespace tvm {
namespace relay {
......
......@@ -3,11 +3,11 @@
* \file resize.cc
* \brief Image operators
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h>
#include <topi/elemwise.h>
#include <topi/image/resize.h>
#include "../layout.h"
#include "../op_common.h"
namespace tvm {
......@@ -28,17 +28,18 @@ bool ResizeRel(const Array<Type>& types,
const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW))
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
oshape[2] = param->size[0];
oshape[3] = param->size[1];
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, param->size[0]);
oshape.Set(3, param->size[1]);
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/relay/op/layout.cc
* \brief Layout expression.
*/
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(LayoutNode);
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout) {
CHECK_EQ(src_layout.ndim(), src.size());
if (src_layout == dst_layout) {
return src;
} else if (!src_layout.defined()) {
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
} else if (!dst_layout.defined()) {
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
}
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from "
<< src_layout << " to " << dst_layout;
std::vector<IndexExpr> dst(dst_layout.ndim());
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_dim = src_layout[i];
if (Layout::IsSuperdim(src_dim)) {
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
int src_factor = src_layout.Subsizeof(src_dim);
int dst_factor = dst_layout.Subsizeof(src_dim);
IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) {
CHECK(is_const_int(src[src_minor_pos], src_factor))
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
src_dim_size *= src_factor;
}
dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
CHECK_LE(dst_factor, const_src_dim_size[0])
<< "Converting " << Array<IndexExpr>(src)
<< " from " << src_layout
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
}
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
}
}
}
return dst;
}
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout) {
std::vector<IndexExpr> ret(src.size());
for (size_t i = 0; i < src.size(); ++i) {
ret[i] = src[i];
}
return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
......@@ -4,6 +4,7 @@
* \brief Property def of nn operators.
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h>
......@@ -14,7 +15,6 @@
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
#include "../op_common.h"
#include "../layout.h"
namespace tvm {
namespace relay {
......
......@@ -3,12 +3,12 @@
* \file pad.cc
* \brief Implementation of operator pad
*/
#include <tvm/data_layout.h>
#include <tvm/ir_operator.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <topi/nn.h>
#include <vector>
#include "../layout.h"
#include "../op_common.h"
namespace tvm {
......
......@@ -3,12 +3,12 @@
* \file pooling.cc
* \brief Pooling operators
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h>
#include <vector>
#include "../layout.h"
#include "../../pass/alter_op_layout.h"
namespace tvm {
......@@ -32,14 +32,15 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.Indexof('W') == raw_layout.Indexof('W') &&
input.Indexof('H') == raw_layout.Indexof('H') &&
!input.Contains('w') && !input.Contains('h')) {
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
}
}
return Array<Array<Layout> >{{params->layout}, {params->layout}};
Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}
template <typename AttrType>
......@@ -59,13 +60,13 @@ bool Pool2DRel(const Array<Type>& types,
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') &&
!layout.Contains('h') && !layout.Contains('w'))
CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H');
const auto widx = layout.Indexof('W');
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
IndexExpr pad_h, pad_w;
if (param->padding.size() == 1) {
......@@ -125,6 +126,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
auto pool_size = param->pool_size;
......@@ -132,10 +134,13 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
auto padding = param->padding;
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW")))
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1) << "max_pool2d does not support input split on height";
CHECK_EQ(layout.Indexof('w'), -1) << "max_pool2d does not support input split on width";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "max_pool2d does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
......@@ -271,13 +276,13 @@ bool GlobalPool2DRel(const Array<Type>& types,
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') &&
!layout.Contains('h') && !layout.Contains('w'))
CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H');
const auto widx = layout.Indexof('W');
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
Array<IndexExpr> oshape(dshape);
oshape.Set(hidx, 1);
oshape.Set(widx, 1);
......@@ -293,14 +298,15 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW")))
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1)
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "global_avg_pool2d does not support input split on height";
CHECK_EQ(layout.Indexof('w'), -1)
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "global_avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
......
......@@ -3,6 +3,7 @@
* \file upsampling.cc
* \brief upsampling operator
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h>
......@@ -11,7 +12,6 @@
#include <topi/nn/upsampling.h>
#include <vector>
#include "../op_common.h"
#include "../layout.h"
namespace tvm {
namespace relay {
......@@ -31,18 +31,20 @@ bool UpSamplingRel(const Array<Type>& types,
const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW))
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "UpSampling only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
auto oshape = layout_converter.ForwardShape(data->shape);
oshape[2] = oshape[2] * param->scale;
oshape[3] = oshape[3] * param->scale;
oshape.Set(2, oshape[2] * param->scale);
oshape.Set(3, oshape[3] * param->scale);
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
......
......@@ -7,6 +7,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <tvm/data_layout.h>
#include <topi/transform.h>
#include <topi/elemwise.h>
#include <topi/broadcast.h>
......@@ -16,7 +17,6 @@
#include "../op_common.h"
#include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h"
#include "../layout.h"
namespace tvm {
namespace relay {
......@@ -218,7 +218,7 @@ Array<Array<Layout>> ConcatenateLayout(
Layout ret;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
Layout::LayoutDim concate_dim = old_in_layouts[0][axis];
const auto& concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) {
if (new_in_layouts[i].ndim() > axis &&
new_in_layouts[i][axis] == concate_dim) {
......@@ -234,7 +234,7 @@ Array<Array<Layout>> ConcatenateLayout(
}
}
if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) {
if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
}
......@@ -1682,46 +1682,10 @@ Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const LayoutTransformAttrs *param = attrs.as<LayoutTransformAttrs>();
const auto* param = attrs.as<LayoutTransformAttrs>();
CHECK(param != nullptr);
Layout src_layout(param->src_layout);
Layout dst_layout(param->dst_layout);
if (src_layout.Equals(dst_layout)) {
return Array<Tensor>{ inputs[0] };
}
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from " << param->src_layout << " to " << param->dst_layout;
const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout);
return Array<Tensor> {
topi::layout_transform(inputs[0], out_shape, [&](const Array<tvm::Var>& dst_indices) {
std::vector<tvm::Expr> dst_to_src_indices;
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_axis = src_layout[i];
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.Subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.Subsizeof(src_axis));
tvm::Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::IsSuperdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::IsSubdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<tvm::Expr>(dst_to_src_indices);
})
return Array<Tensor>{
topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
};
}
......@@ -1738,10 +1702,12 @@ bool LayoutTransformRel(const Array<Type>& types,
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout);
CHECK(layout_converter.defined())
<< "cannot convert from " << params->src_layout << " to " << params->dst_layout;
const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout);
const auto& out_shape = layout_converter.ForwardShape(data->shape);
reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
return true;
}
......
......@@ -26,7 +26,7 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
if (src_layout.Equals(dst_layout)) { return raw; }
CHECK(src_layout.defined() && dst_layout.defined())
<< "Cannot insert layout transform because there are undefined layouts";
CHECK(src_layout.Convertible(dst_layout))
CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
<< "Cannot insert layout transform because there are inconvertible layouts: "
<< src_layout << " v.s. " << dst_layout;
static auto &transform_op = Op::Get("layout_transform");
......
......@@ -9,10 +9,9 @@
#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#include <tvm/data_layout.h>
#include <tvm/relay/expr.h>
#include "../op/layout.h"
namespace tvm {
namespace relay {
......@@ -78,9 +77,9 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx,
layouts[defined_idx].Sublayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size()));
layouts[defined_idx].SubLayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
} else {
// only know the tensor with smaller dimensions,
......@@ -90,21 +89,22 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
}
} else {
// try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1;
int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
int small_idx = 1 - large_idx;
Layout ret = layouts[large_idx];
// extract common part
size_t i = layouts[large_idx].ndim();
for (; i != 0; --i) {
auto dim = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) {
const auto& axis = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(axis.ToPrimal())) {
break;
}
}
Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i);
if (!layouts[small_idx].Convertible(common_part)) { // fail
Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
// not convertible
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
......
......@@ -91,8 +91,10 @@ class BranchGroupFinder : private ExprVisitor {
CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW);
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW);
const auto shape_a = BijectiveLayoutNode::make(
Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
const auto shape_b = BijectiveLayoutNode::make(
Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
......
......@@ -6,12 +6,12 @@
* \brief Fold axis scaling into weights of
* conv/dense operators.
*/
#include <tvm/data_layout.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "../op/layout.h"
namespace tvm {
......@@ -435,8 +435,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C');
int c_small_axis = data_layout.Indexof('c');
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>();
......@@ -449,7 +449,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('i') < 0 &&
if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
......@@ -473,15 +473,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C');
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O');
int big_ic_axis = kernel_layout.Indexof('I');
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
......@@ -857,8 +857,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C');
int c_small_axis = out_layout.Indexof('c');
int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
......@@ -869,8 +869,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('o') < 0 &&
kernel_layout.Indexof('i') < 0 &&
if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return MessageNode::make({c_big_axis}, false);
......@@ -891,16 +891,16 @@ Expr Conv2DBackwardTransform(const Call& call,
CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C');
int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('o'), -1);
CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(message->axes.size() == 1 &&
c_big_axis == message->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O');
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
......
......@@ -11,7 +11,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "../op/layout.h"
#include <tvm/data_layout.h>
namespace tvm {
namespace relay {
......@@ -51,8 +51,8 @@ int64_t ConvMacCount(const Call& call_node) {
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_attr->data_layout;
int32_t C_ind = Layout(data_layout).Indexof('C');
int32_t c_ind = Layout(data_layout).Indexof('c');
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK(C_ind != -1)
<< "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
......
......@@ -8,13 +8,13 @@
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/nn.h>
#include <string>
#include "../op/layout.h"
namespace tvm {
......@@ -155,9 +155,8 @@ inline bool IsDepthwiseConv2D(const Call& call,
const Conv2DAttrs* param,
const Layout& kernel_layout) {
static const Layout kOIHW("OIHW");
auto wshape = ConvertLayout(
call->args[1]->type_as<TensorTypeNode>()->shape,
kernel_layout, kOIHW);
const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
return is_const_int(wshape[0], param->groups) &&
is_const_int(wshape[1], 1);
}
......
"""Test layout and bijective-layout node"""
import tvm
from topi.util import get_const_tuple
def test_layout():
layout = tvm.layout("NCHW16c")
assert layout is not None
assert isinstance(layout, tvm.tensor.Layout)
assert layout.factor_of("c") == 16
assert layout.factor_of("C") == 16
assert layout.factor_of("N") == -1
assert layout.index_of("N") == 0
assert layout.index_of("C") == 1
assert layout.index_of("H") == 2
assert layout.index_of("W") == 3
assert layout.index_of("c") == 4
assert layout.index_of("O") == -1
assert "N" in layout
assert "C" in layout
assert "H" in layout
assert "W" in layout
assert "c" in layout
assert "O" not in layout
assert layout[0] == "N"
assert layout[1] == "C"
assert layout[2] == "H"
assert layout[3] == "W"
assert layout[4] == "c"
assert layout[-1] == "c"
def test_bilayout_convertible():
# not convertible
assert tvm.bijective_layout("NCHW", "ABCD") is None
# convertible
assert tvm.bijective_layout("NCHW", "NCHW16c") is not None
def test_bilayout_shape():
bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
assert isinstance(bilayout, tvm.tensor.BijectiveLayout)
dst_shape = bilayout.forward_shape((1, 32, 7, 7))
assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
src_shape = bilayout.backward_shape(dst_shape)
assert get_const_tuple(src_shape) == (1, 32, 7, 7)
def test_bilayout_index():
bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
dst_index = bilayout.forward_index([0, 18, 6, 6])
assert get_const_tuple(dst_index) == (0, 1, 6, 6, 2)
src_index = bilayout.backward_index([0, 1, 6, 6, 2])
assert get_const_tuple(src_index) == (0, 18, 6, 6)
if __name__ == "__main__":
test_layout()
test_bilayout_convertible()
test_bilayout_shape()
test_bilayout_index()
......@@ -450,28 +450,5 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
return tvm::compute(output_shape, l, name, tag);
}
using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;
/*!
* \brief Transform the layout according to the mapping function \p to_src_indices.
* \param src the source input.
* \param dst_shape the output shape.
* \param to_src_indices the mapping function from input index to output index.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape \p dst_shape.
*/
inline Tensor layout_transform(const Tensor& src,
const Array<Expr>& dst_shape,
const FLayoutIndicesTransform& to_src_indices,
const std::string name = "layout_transform",
const std::string tag = kInjective) {
auto src_shape = src->shape;
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
return src(to_src_indices(dst_indices));
}, name, tag);
}
} // namespace topi
#endif // TOPI_NN_H_
......@@ -16,6 +16,7 @@
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "tvm/tvm.h"
#include "tvm/data_layout.h"
namespace topi {
using namespace tvm;
......@@ -882,5 +883,43 @@ inline Tensor arange(const Expr start,
}, name, tag);
}
/*!
* \brief Transform the layout according to \p src_layout and \p dst_layout
* \param src the source input.
* \param src_layout the source layout.
* \param dst_layout the destination layout.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape in \p dst_layout
*/
inline Tensor layout_transform(const Tensor& src,
const std::string& src_layout,
const std::string& dst_layout,
const std::string name = "layout_transform",
const std::string tag = kInjective) {
Layout src_layout_struct = LayoutNode::make(src_layout);
Layout dst_layout_struct = LayoutNode::make(dst_layout);
if (src_layout_struct.Equals(dst_layout_struct)) {
return src;
}
CHECK(src_layout_struct.defined() && dst_layout_struct.defined())
<< "cannot convert from/to undefined layout";
auto layout_converter = BijectiveLayoutNode::make(src_layout_struct, dst_layout_struct);
CHECK(layout_converter.defined())
<< "cannot convert from " << src_layout << " to " << dst_layout;
Array<Expr> dst_shape = layout_converter.ForwardShape(src->shape);
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
Array<Expr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
Array<Expr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
return src(src_indices);
}, name, tag);
}
} // namespace topi
#endif // TOPI_TRANSFORM_H_
......@@ -318,3 +318,20 @@ def arange(start, stop=None, step=1, dtype="float32"):
stop = start
start = 0
return cpp.arange(start, stop, step, dtype)
def layout_transform(array, src_layout, dst_layout):
"""Transform the layout according to src_layout and dst_layout
Parameters
----------
array : tvm.Tensor
The source array.
src_layout : str
the source layout.
dst_layout : str
the destination layout.
"""
return cpp.layout_transform(array, src_layout, dst_layout)
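A minimal usage sketch of the new TOPI wrapper (declarative only; the placeholder shape is an assumption, and the full build-and-run check lives in test_layout_transform further down in this diff):

```python
import tvm
import topi

# Repack a 4-D NCHW tensor into the blocked NCHW16c form.
A = tvm.placeholder((1, 32, 8, 8), name="A", dtype="float32")
B = topi.layout_transform(A, "NCHW", "NCHW16c")
# B is an injective compute op with shape (1, 2, 8, 8, 16).
```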
......@@ -272,6 +272,11 @@ TVM_REGISTER_GLOBAL("topi.split")
}
});
TVM_REGISTER_GLOBAL("topi.layout_transform")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = layout_transform(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
......
......@@ -449,6 +449,34 @@ def test_arange():
verify_arange(20, 1, -1.5)
def test_layout_transform():
in_shape = (1, 32, 8, 8)
A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
B = topi.layout_transform(A, "NCHW", "NCHW16c")
input = np.random.uniform(size=in_shape).astype(A.dtype)
output = np.transpose(input, axes=(0, 2, 3, 1))
output = np.reshape(output, newshape=(1, 8, 8, 2, 16))
output = np.transpose(output, axes=(0, 3, 1, 2, 4))
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
tvm_input = tvm.nd.array(input, ctx)
tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=B.dtype)
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
f = tvm.build(s, [A, B], device, name="layout_transform")
f(tvm_input, tvm_output)
tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
for backend in get_all_backend():
check_device(backend)
if __name__ == "__main__":
test_strided_slice()
test_concatenate()
......@@ -462,3 +490,4 @@ if __name__ == "__main__":
test_take()
test_gather_nd()
test_arange()
test_layout_transform()