Commit ee79703c by Yizhi Liu Committed by Lianmin Zheng

[Lang] Layout in TVM node system (#2509)

* move layout.h & layout.cc from relay to tvm

* change ConvertLayout in relay to bijectiveLayout->Forward/backward

* add first test case

* add LayoutAxis

* add LayoutAxis struct and compiles

* simplify BijectiveLayout rule construct

* polish func name for Layout, move impl to .cc, remove Layout::defined(), add defined() checker

* partially add layout py support

* add layout test cases

* add doc for tvm.layout & tvm.bijective_layout

* fix lint

* fix lint

* fix layout name generation bug

* fix layout typo

* address comments and add topi.layout_transform

* layout.h->data_layout.h, test_lang_layout.py->test_lang_data_layout.py
parent ddc31fd4
...@@ -68,6 +68,7 @@ List of operators ...@@ -68,6 +68,7 @@ List of operators
topi.greater_equal topi.greater_equal
topi.less_equal topi.less_equal
topi.arange topi.arange
topi.layout_transform
topi.image.resize topi.image.resize
...@@ -125,6 +126,7 @@ topi ...@@ -125,6 +126,7 @@ topi
.. autofunction:: topi.greater .. autofunction:: topi.greater
.. autofunction:: topi.less .. autofunction:: topi.less
.. autofunction:: topi.arange .. autofunction:: topi.arange
.. autofunction:: topi.layout_transform
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/data_layout.h
* \brief Layout expression to describe the data organization of a tensor.
* And BijectiveLayout to describe the mapping between two data layouts.
*/
#ifndef TVM_DATA_LAYOUT_H_
#define TVM_DATA_LAYOUT_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
#include "ir_operator.h"
namespace tvm {
/*!
 * \brief A single named axis of a data layout.
 *
 * An axis is identified by one character: an upper-case letter (A-Z) is a
 * primal axis and the matching lower-case letter (a-z) is its subordinate axis.
 * The constructor and the copy operations are private, so instances are only
 * handed out by reference through Get() / make().
 */
class LayoutAxis {
 public:
  /*! \brief Get the singleton LayoutAxis for a one-character name. */
  static const LayoutAxis& Get(const char name);

  /*! \brief Get the singleton LayoutAxis using itvar->var->name_hint. */
  static const LayoutAxis& Get(const IterVar& itvar);

  /*! \brief Get the singleton LayoutAxis using name[0] (size of name must be 1). */
  static const LayoutAxis& make(const std::string& name);

  /*! \return whether this axis is primal, i.e. named by an upper-case letter. */
  bool IsPrimal() const {
    return 'A' <= name_ && name_ <= 'Z';
  }

  /*! \return the axis name as a one-character string. */
  std::string name() const {
    return std::string(1, name_);
  }

  /*!
   * \brief Switch between the primal and the subordinate form of this axis.
   * \return the subordinate axis if this one is primal, otherwise the primal axis.
   */
  const LayoutAxis& ToDual() const {
    if (IsPrimal()) {
      return LayoutAxis::Get(name_ - 'A' + 'a');
    }
    return LayoutAxis::Get(name_ - 'a' + 'A');
  }

  /*! \return the primal axis; the axis itself when it is already primal. */
  const LayoutAxis& ToPrimal() const {
    if (IsPrimal()) return *this;
    return ToDual();
  }

  /*! \return the subordinate axis; the axis itself when it is already subordinate. */
  const LayoutAxis& ToSubordinate() const {
    if (IsPrimal()) return ToDual();
    return *this;
  }

  /*! \brief Two axes are equal when they carry the same name character. */
  bool operator==(const LayoutAxis& rhs) const {
    return rhs.name_ == name_;
  }

  /*! \brief Write the axis name to an output stream. */
  friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
    return os << l.name();
  }

 private:
  // Tables holding the axis instances for 'A'-'Z' and 'a'-'z'.
  static const LayoutAxis UPPER_CASE[];
  static const LayoutAxis LOWER_CASE[];
  // Copying is disabled: declared private and intentionally left undefined.
  LayoutAxis(const LayoutAxis&);
  LayoutAxis& operator=(const LayoutAxis&);
  explicit LayoutAxis(const char name) : name_(name) {}
  const char name_;
};
class Layout;
// Internal node container of Layout
class LayoutNode : public Node {
public:
/*! \brief string representation of layout */
std::string name;
/*! \brief specify each axis of the layout,
* in which the variable name is the name of the axis.
* The IterVar's extent indicates the size of the axis,
* it is a variable for a primal axis, but a constant for a subordinate axis.
*/
Array<IterVar> axes;
/*! \brief expose the "name" and "axes" fields to the attribute visitor */
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("axes", &axes);
}
/*!
* \brief create a Layout from its string representation.
* \param layout the layout string, e.g. NCHW16c.
* \return the constructed Layout.
*/
TVM_DLL static Layout make(const std::string& layout);
static constexpr const char* _type_key = "Layout";
TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
};
/*!
* \brief Layout is to describe how data is organized within an N-dimension tensor.
* It is composed of upper cases, lower cases and numbers,
* where upper case indicates a primal axis and
* the corresponding lower case with factor size indicates the subordinate axis.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
*/
class Layout : public NodeRef {
public:
/*! \brief construct from an existing node pointer */
explicit Layout(NodePtr<Node> n) : NodeRef(n) {}
/*! \brief default constructor */
Layout() = default;
/*! \brief construct from a list of axes; the name is generated from the axes */
explicit Layout(const Array<IterVar>& axes);
/*! \brief construct from a string */
Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
/*!
* \brief construct from a string.
* \param name input in layout convention:
* upper case indicates a dimension and
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" or an empty string is passed.
*/
Layout(const std::string& name); // NOLINT(*)
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
const LayoutNode* operator->() const {
return static_cast<const LayoutNode*>(node_.get());
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
LayoutNode* operator->() {
return static_cast<LayoutNode*>(node_.get());
}
/*!
* \brief Return an undefined layout.
* \return a (global) undefined layout.
*/
static const Layout& Undef() {
static Layout undef;
return undef;
}
/*!
* \brief Returns a sub-layout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
* \return A newly constructed Layout object.
*/
Layout SubLayout(size_t pos, size_t len) const;
/*!
* \brief Split \p axis by \p size and put the sub-axis to position \p target_pos.
* \param axis The source axis to be split. It must be a primal-axis;
* \param target_pos The target position of the newly split subordinate-axis.
* \param factor size of the sub-dimension.
* \return A newly constructed Layout object.
*/
Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const;
/*! \return number of dimensions, 0 if the layout is undefined */
inline size_t ndim() const {
if (!defined()) return 0;
return operator->()->axes.size();
}
/*! \return number of primal axes, 0 if the layout is undefined */
inline size_t ndim_primal() const {
if (!defined()) return 0;
size_t ct = 0;
for (auto x : operator->()->axes) {
if (LayoutAxis::Get(x).IsPrimal()) {
ct++;
}
}
return ct;
}
/*!
* \brief return the index of the input axis.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param axis the input axis.
* \return the index or -1 if not found.
*/
inline int32_t IndexOf(const LayoutAxis& axis) const {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
}
return -1;
}
/*!
* \brief Get the factor size of the subordinate axis.
* \param axis the input primal-axis or subordinate-axis.
* \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis),
* or the size of \p axis itself (if \p axis is a subordinate-axis).
* Return -1 if \p axis is not in the layout or the layout is undefined.
*/
int32_t FactorOf(const LayoutAxis& axis) const;
/*!
* \brief Whether the layout contains an axis.
* \param axis axis to be checked.
* \return Whether the layout contains the axis; false if the layout is undefined.
*/
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const IterVar var : operator->()->axes) {
if (var->var.get()->name_hint == axis.name()) {
return true;
}
}
return false;
}
/*!
* \brief Get the axis at position \p i.
* A negative \p i counts from the end of the layout (ndim() + i).
* The layout must be defined and the resulting index must be within [0, ndim()).
* \param i the position of the axis.
* \return the axis at that position.
*/
const LayoutAxis& operator[](int32_t i) const {
CHECK(defined()) << "Try to access axis from an undefined layout.";
int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
const IterVar axis = operator->()->axes[index];
return LayoutAxis::Get(axis);
}
/*! \return the string description of the layout, "__undef__" if it is undefined */
inline std::string name() const {
if (!defined()) return "__undef__";
return operator->()->name;
}
/*!
* \brief Whether the two layouts are equal, compared by their string description.
* \param rhs Another layout.
* \return whether the two layouts are equal.
*/
inline bool Equals(const Layout &rhs) const {
return name() == rhs.name();
}
/*!
* \brief allow output string of layout to ostream
* \param os the output stream
* \param l the layout
* \return the ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
os << l.name();
return os;
}
/*! \brief specify container node */
using ContainerType = LayoutNode;
};
class BijectiveLayout;
// Internal node container of BijectiveLayout
class BijectiveLayoutNode : public Node {
public:
/*! \brief Describes how source axes can be mapped to the destination axes,
* e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
*/
Array<Expr> forward_rule;
/*! \brief Describes how destination axes can be mapped to the source axes */
Array<Expr> backward_rule;
/*! \brief The source layout */
Layout src_layout;
/*! \brief The destination layout */
Layout dst_layout;
/*! \brief expose both layouts and both rules to the attribute visitor */
void VisitAttrs(AttrVisitor* v) final {
v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule);
v->Visit("backward_rule", &backward_rule);
}
static constexpr const char* _type_key = "BijectiveLayout";
TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node);
/*!
* \brief Build the mapping between two layouts.
* \param src_layout The source layout.
* \param dst_layout The destination layout.
* \return The mapping, or an undefined BijectiveLayout if no forward rule can be built.
*/
TVM_DLL static BijectiveLayout make(const Layout& src_layout,
const Layout& dst_layout);
};
/*! \brief Bijective function mapping for data layout transformation.
* Given two Layout, BijectiveLayout build and store the mapping rules,
* provides API to transform N-dimension tensor from the source indices (i0, i1, …, im)
* to the destination indices (j0, j1, … jm).
*/
class BijectiveLayout : public NodeRef {
public:
/*! \brief default constructor, creates an undefined mapping */
BijectiveLayout() = default;
/*! \brief construct from an existing node pointer */
explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {}
// Given the source shape, infer the destination shape.
TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
// Given the destination shape, recover the source shape.
TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const;
// Given the source indices, infer the destination indices.
TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const;
// Given the destination indices, recover the source indices.
TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const BijectiveLayoutNode* operator->() const;
/*! \brief specify container node */
using ContainerType = BijectiveLayoutNode;
};
// Read-only access to the underlying BijectiveLayoutNode held by this reference.
inline const BijectiveLayoutNode* BijectiveLayout::operator->() const {
return static_cast<const BijectiveLayoutNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_DATA_LAYOUT_H_
...@@ -674,42 +674,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] ...@@ -674,42 +674,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& outputs) { const Array<Tensor>& outputs) {
const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed); const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
return Array<Tensor>{
Layout src_layout(param.src_layout); topi::layout_transform(inputs[0], param.src_layout, param.dst_layout)
Layout dst_layout(param.dst_layout);
if (src_layout == dst_layout) {
return Array<Tensor>{ inputs[0] };
} else if (!src_layout.defined() || !dst_layout.defined()) {
LOG(FATAL) << "cannot convert from/to undefined layout";
}
CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << param.src_layout
<< " to " << param.dst_layout;
return Array<Tensor> {
topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) {
std::vector<Expr> dst_to_src_indices;
for (Layout::LayoutDim src_axis : src_layout) {
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::is_superdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::is_subdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<Expr>(dst_to_src_indices);
})
}; };
}) })
.set_support_level(1); .set_support_level(1);
......
...@@ -515,7 +515,7 @@ def decl_buffer(shape, ...@@ -515,7 +515,7 @@ def decl_buffer(shape,
scope="", scope="",
data_alignment=-1, data_alignment=-1,
offset_factor=0): offset_factor=0):
"""Decleare a new symbolic buffer. """Declare a new symbolic buffer.
Normally buffer is created automatically during lower and build. Normally buffer is created automatically during lower and build.
This is only needed if user want to specify their own buffer layout. This is only needed if user want to specify their own buffer layout.
...@@ -587,6 +587,49 @@ def decl_buffer(shape, ...@@ -587,6 +587,49 @@ def decl_buffer(shape,
data, dtype, shape, strides, elem_offset, name, scope, data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor) data_alignment, offset_factor)
def layout(layout_str):
    """Build a Layout node out of its string form.

    Parameters
    ----------
    layout_str : str
        The layout description, made of upper-case letters, lower-case
        letters and numbers. An upper-case letter names a primal axis; the
        matching lower-case letter, prefixed by a factor size, names the
        subordinate axis. For example, NCHW16c can describe a 5-D tensor of
        [batch_size, channel, height, width, channel_block], where the
        subordinate axis channel_block=16 is the factor size of the primal
        axis C (channel).

    Returns
    -------
    layout : Layout
        The created layout
    """
    node = _api_internal._Layout(layout_str)
    return node
def bijective_layout(src_layout, dst_layout):
    """Build the bijective mapping between two layouts.

    Parameters
    ----------
    src_layout : str or Layout
        source layout; a string is first converted with :any:`layout`.

    dst_layout : str or Layout
        destination layout; a string is first converted with :any:`layout`.

    Returns
    -------
    bijective_layout : BijectiveLayout
        The created bijective layout
    """
    src = layout(src_layout) if isinstance(src_layout, str) else src_layout
    dst = layout(dst_layout) if isinstance(dst_layout, str) else dst_layout
    return _api_internal._BijectiveLayout(src, dst)
def _IterVar(dom, name, iter_type, thread_tag=''): def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar """Internal function to create IterVar
......
...@@ -185,3 +185,142 @@ class HybridOp(Operation): ...@@ -185,3 +185,142 @@ class HybridOp(Operation):
def axis(self): def axis(self):
"""Represent axis of IterVar, also defined when it is a HybridOp""" """Represent axis of IterVar, also defined when it is a HybridOp"""
return self.__getattr__("axis") return self.__getattr__("axis")
@register_node
class Layout(NodeBase):
    """Layout is composed of upper cases, lower cases and numbers,
    where upper case indicates a primal axis and
    the corresponding lower case with factor size indicates the subordinate axis.
    For example, NCHW16c can describe a 5-D tensor of
    [batch_size, channel, height, width, channel_block].
    Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

    Do not construct directly, use :any:`layout` instead.
    See the documentation of :any:`layout` for more details.

    See Also
    --------
    layout : Declare a layout
    """
    def __str__(self):
        return self.name

    def __repr__(self):
        return "Layout(" + self.name + ")"

    def __len__(self):
        return _api_internal._LayoutNdim(self)

    def __contains__(self, axis):
        # Only a single alphabetic character can name an axis.
        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name

    def __getitem__(self, index):
        """Get the name of the axis at a position.

        Parameters
        ----------
        index : int
            The position of the axis. A negative value counts from the end.

        Returns
        -------
        axis : str
            The name of the axis.

        Raises
        ------
        IndexError
            If index is outside [-len(self), len(self)).
        """
        ndim = len(self)
        # Reject both ends here so an out-of-range index surfaces as
        # IndexError instead of reaching the bounds check on the C++ side.
        if index >= ndim or index < -ndim:
            raise IndexError("Layout index out of range")
        return _api_internal._LayoutGetItem(self, index)

    def index_of(self, axis):
        """Get the index of an axis

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        index : int
            The index of the axis, -1 if not found.
        """
        return _api_internal._LayoutIndexOf(self, axis)

    def factor_of(self, axis):
        """Get the factor size of the subordinate axis.

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        factor : int
            the size of the subordinate-axis of axis (if axis is a primal-axis),
            or the size of axis itself (if axis is a subordinate-axis).
            Return -1 if axis is not in the layout.
        """
        return _api_internal._LayoutFactorOf(self, axis)
@register_node
class BijectiveLayout(NodeBase):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion in both directions.
Do not construct directly, use :any:`bijective_layout` instead.
See the documentation of :any:`bijective_layout` for more details.
See Also
--------
bijective_layout : Declare a bijective layout converter
"""
def forward_index(self, index):
"""Given the indices of the src-layout, infer the dst index.
Parameters
----------
index: Array of Expr
The indices in src-layout.
Returns
-------
dst_index: Array of Expr
The inferred indices in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardIndex(self, index)
def backward_index(self, index):
"""Given the indices of the dst-layout, infer the src index.
Parameters
----------
index: Array of Expr
The indices in dst-layout.
Returns
-------
src_index: Array of Expr
The inferred indices in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardIndex(self, index)
def forward_shape(self, shape):
"""Given the shape of the src-layout, infer the dst shape.
Parameters
----------
shape: Array of Expr
The shape in src-layout.
Returns
-------
dst_shape: Array of Expr
The inferred shape in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardShape(self, shape)
def backward_shape(self, shape):
"""Given the shape of the dst-layout, infer the src shape.
Parameters
----------
shape: Array of Expr
The shape in dst-layout.
Returns
-------
src_shape: Array of Expr
The inferred shape in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardShape(self, shape)
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/data_layout.h>
namespace tvm { namespace tvm {
...@@ -224,6 +225,63 @@ TVM_REGISTER_API("_BufferVStore") ...@@ -224,6 +225,63 @@ TVM_REGISTER_API("_BufferVStore")
.vstore(args[1], args[2]); .vstore(args[1], args[2]);
}); });
// Create a Layout from its string form (args[0]).
TVM_REGISTER_API("_Layout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LayoutNode::make(args[0]);
});
// Index of the axis named by args[1] within the layout args[0].
TVM_REGISTER_API("_LayoutIndexOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.IndexOf(LayoutAxis::make(args[1]));
});
// Factor size of the axis named by args[1] within the layout args[0].
TVM_REGISTER_API("_LayoutFactorOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.FactorOf(LayoutAxis::make(args[1]));
});
// Number of axes of the layout args[0], returned as int64_t.
TVM_REGISTER_API("_LayoutNdim")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(args[0].operator Layout().ndim());
});
// Name of the axis at position args[1] of the layout args[0].
TVM_REGISTER_API("_LayoutGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
const LayoutAxis& axis = args[0].operator Layout()[args[1]];
*ret = axis.name();
});
// Create a BijectiveLayout from a source layout (args[0]) and a destination layout (args[1]).
TVM_REGISTER_API("_BijectiveLayout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BijectiveLayoutNode::make(args[0], args[1]);
});
// Map source indices (args[1]) to destination indices.
TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardIndex(args[1]);
});
// Map destination indices (args[1]) back to source indices.
TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardIndex(args[1]);
});
// Infer the destination shape from a source shape (args[1]).
TVM_REGISTER_API("_BijectiveLayoutForwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardShape(args[1]);
});
// Recover the source shape from a destination shape (args[1]).
TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardShape(args[1]);
});
TVM_REGISTER_API("_Tensor") TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0], *ret = TensorNode::make(args[0],
......
/*!
* Copyright (c) 2019 by Contributors
* \file src/lang/data_layout.cc
* \brief Data Layout expression.
*/
#include <tvm/data_layout.h>
#include <tvm/ir_pass.h>
namespace tvm {
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode);
// Table of primal axes, indexed by (name - 'A').
const LayoutAxis LayoutAxis::UPPER_CASE[] = {
LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
LayoutAxis('Z')
};
// Table of subordinate axes, indexed by (name - 'a').
const LayoutAxis LayoutAxis::LOWER_CASE[] = {
LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
LayoutAxis('z')
};
/*!
 * \brief Look up the axis instance for a one-character name.
 * \param name the axis name; must be within A-Z or a-z.
 * \return the entry of UPPER_CASE or LOWER_CASE that corresponds to \p name.
 */
const LayoutAxis& LayoutAxis::Get(const char name) {
  const bool is_upper = (name >= 'A' && name <= 'Z');
  const bool is_lower = (name >= 'a' && name <= 'z');
  CHECK(is_upper || is_lower)
    << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z.";
  if (is_upper) {
    return LayoutAxis::UPPER_CASE[name - 'A'];
  }
  return LayoutAxis::LOWER_CASE[name - 'a'];
}
/*!
 * \brief Look up the axis instance named by an IterVar.
 * \param itvar the iteration variable; its var's name_hint must be one character long.
 * \return the axis for that character.
 */
const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
  const std::string& hint = itvar->var.get()->name_hint;
  CHECK_EQ(hint.size(), 1) << "Invalid layout axis " << hint;
  const char symbol = hint[0];
  return LayoutAxis::Get(symbol);
}
/*!
 * \brief Look up the axis instance for a one-character string.
 * \param name the axis name; its length must be 1.
 * \return the axis for name[0].
 */
const LayoutAxis& LayoutAxis::make(const std::string& name) {
  CHECK_EQ(name.length(), 1) << "Invalid axis " << name;
  const char symbol = name[0];
  return LayoutAxis::Get(symbol);
}
/*!
 * \brief Construct a layout from a list of axes and derive its name from them.
 *
 * For every axis whose extent is an IntImm the (positive) extent is written
 * in front of the axis name, e.g. "16c".
 * \param axes the axes; each variable name must be a single letter.
 */
Layout::Layout(const Array<IterVar>& axes) {
  node_ = make_node<LayoutNode>();
  LayoutNode* self = operator->();
  self->axes = axes;

  std::ostringstream text;
  for (const IterVar& iv : axes) {
    const auto* factor = iv->dom->extent.as<IntImm>();
    if (factor != nullptr) {
      CHECK_GT(factor->value, 0);
      text << factor->value;
    }
    const std::string& hint = iv->var.get()->name_hint;
    CHECK_EQ(hint.size(), 1) << "Invalid layout axis "
                             << hint;
    const char c = hint[0];
    const bool is_letter = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
    CHECK(is_letter) << "Invalid layout axis " << c;
    text << hint;
  }
  self->name = text.str();
}
/*!
 * \brief Construct a layout by parsing its string form, e.g. "NCHW16c".
 *
 * An empty string or "__undef__" leaves the layout undefined.
 * Each upper-case letter X becomes a primal axis whose extent is a variable
 * named "X_shape"; each lower-case letter becomes a subordinate axis whose
 * extent is the decimal factor written right in front of it.
 * The layout is rejected when a factor precedes a primal axis, a subordinate
 * axis has no positive factor, a character is not a letter or a digit,
 * an axis appears twice, or a subordinate axis has no primal counterpart.
 * \param name the layout string.
 */
Layout::Layout(const std::string& name) {  // NOLINT(*)
  if (name.empty() || name == "__undef__") return;

  node_ = make_node<LayoutNode>();
  LayoutNode *node = operator->();
  node->name = name;

  // parse layout string
  int32_t factor = 0;
  for (char c : name) {
    if (c >= 'A' && c <= 'Z') {
      CHECK_EQ(factor, 0) << "Invalid layout " << name
                          << ": invalid factor size " << factor
                          << " before dimension " << c;
      std::string shape_name("_shape");
      shape_name.insert(0, 1, c);
      IterVar axis = IterVarNode::make(Range(Expr(0), Var(shape_name)),
                                       Var(std::string(1, c)), kDataPar);
      node->axes.push_back(axis);
    } else if (c >= 'a' && c <= 'z') {
      CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
                          << factor << " for dimension " << c;
      IterVar axis = IterVarNode::make(Range(Expr(0), Expr(factor)),
                                       Var(std::string(1, c)), kDataPar);
      node->axes.push_back(axis);
      factor = 0;
    } else if (c >= '0' && c <= '9') {
      // '_' is not part of the layout grammar (it is rejected below), so
      // report the offending factor rather than a stale '_' diagnostic.
      CHECK(factor >= 0) << "Invalid layout " << name << ": invalid factor size " << factor;
      factor = factor * 10 + c - '0';
    } else {
      LOG(FATAL) << "Invalid layout " << name;
    }
  }

  // validate layout
  std::vector<bool> exist_axis(256, false);
  for (const IterVar& v : node->axes) {
    auto axis_str = v->var.get()->name_hint;
    CHECK_EQ(axis_str.size(), 1);
    char axis = axis_str[0];
    CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
    // Index through unsigned char so the subscript never depends on char signedness.
    const size_t slot = static_cast<unsigned char>(axis);
    CHECK(!exist_axis[slot]) << "Invalid layout " << name << ": duplicate axis " << axis;
    exist_axis[slot] = true;
  }
  for (const IterVar& v : node->axes) {
    char axis = v->var.get()->name_hint[0];
    if (axis >= 'a' && axis <= 'z') {
      // Keep the primal name as a char: streaming the raw arithmetic result
      // would print its integer code instead of the letter.
      const char primal = static_cast<char>(axis - 'a' + 'A');
      CHECK(exist_axis[static_cast<unsigned char>(primal)])
        << "Invalid layout " << name << ": missing axis " << primal;
    }
  }
}
// Factory used by the API registry: builds a Layout from its string form.
Layout LayoutNode::make(const std::string& layout) {
return Layout(layout);
}
/*!
 * \brief Extract the axes [pos, pos + len) as a new layout.
 * \param pos The start position.
 * \param len The number of axes to take; it is clamped so that the
 *            sub-layout never runs past the end of this layout.
 * \return The sub-layout, or an undefined layout if this layout is
 *         undefined or \p pos is greater than ndim().
 */
Layout Layout::SubLayout(size_t pos, size_t len) const {
  if (!defined() || pos > ndim()) return Layout::Undef();
  // Clamp against the remaining length instead of testing pos + len:
  // the sum wraps around for very large len and would skip the clamp.
  const size_t remaining = ndim() - pos;
  if (len > remaining) len = remaining;
  Array<IterVar> new_layout;
  const auto& axes = operator->()->axes;
  for (size_t i = pos; i < pos + len; ++i) {
    new_layout.push_back(axes[i]);
  }
  return Layout(new_layout);
}
/*!
 * \brief Split a primal axis and insert the new subordinate axis at \p target_pos.
 * \param axis The primal axis to split; it must exist and must not be split already.
 * \param target_pos Position of the new subordinate axis, at most ndim().
 * \param factor Size of the subordinate axis; must be positive.
 * \return The new layout, or an undefined layout if this layout is undefined.
 */
Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const {
  if (!defined()) return Layout::Undef();

  const std::string& name = operator->()->name;
  const auto axes = operator->()->axes;
  const size_t total = this->ndim();

  CHECK(target_pos <= total) << "Invalid split position "
                             << target_pos << " for layout " << name;
  CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis;
  CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name;
  CHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis
                                               << " has already been split in " << name;
  CHECK(factor > 0) << "Invalid split size " << factor;

  Array<IterVar> new_layout;
  // axes in front of the insertion point
  for (size_t i = 0; i < target_pos; ++i) {
    new_layout.push_back(axes[i]);
  }
  // the newly created subordinate axis
  new_layout.push_back(IterVarNode::make(Range(Expr(0), Expr(factor)),
                                         Var(axis.ToSubordinate().name()), kDataPar));
  // axes behind the insertion point
  for (size_t i = target_pos; i < total; ++i) {
    new_layout.push_back(axes[i]);
  }
  return Layout(new_layout);
}
/*!
 * \brief Get the factor size of the subordinate axis that belongs to \p axis.
 * \param axis A primal or subordinate axis.
 * \return The constant extent of the subordinate axis, or -1 if the layout
 *         is undefined or does not contain that subordinate axis.
 */
int32_t Layout::FactorOf(const LayoutAxis& axis) const {
  if (!defined()) return -1;
  const LayoutAxis& wanted = axis.ToSubordinate();
  for (const IterVar& candidate : operator->()->axes) {
    if (!(wanted == LayoutAxis::Get(candidate))) continue;
    const auto* extent = candidate->dom->extent.as<IntImm>();
    CHECK(extent);
    return static_cast<int32_t>(extent->value);
  }
  return -1;
}
/*!
 * \brief Build the index expressions that map \p src_layout onto \p dst_layout.
 *
 * For every destination axis the source axes sharing its primal axis are
 * summed up (a primal source axis is first scaled by the factor of its
 * subordinate axis, if it has one). The sum is then divided by the
 * destination factor for a primal destination axis, or taken modulo the
 * axis extent for a subordinate destination axis.
 * \param rule Output; one expression per destination axis is appended.
 * \param src_layout The layout the expressions read from.
 * \param dst_layout The layout the expressions produce.
 * \return false if either layout is undefined or some destination axis has
 *         no matching source axis, i.e. the layouts are not convertible.
 */
inline bool GetStoreRule(Array<Expr>* rule,
                         const Layout& src_layout,
                         const Layout& dst_layout) {
  // An undefined destination would skip the loop below and report success
  // with an empty rule; treat any undefined layout as not convertible.
  if (!src_layout.defined() || !dst_layout.defined()) {
    return false;
  }
  for (size_t i = 0; i < dst_layout.ndim(); ++i) {
    const auto& store_axis = dst_layout[i];
    const IterVar& store_axis_impl = dst_layout->axes[i];
    Expr store(0);

    for (size_t j = 0; j < src_layout.ndim(); ++j) {
      const auto& orig_axis = src_layout[j];
      const IterVar& orig_axis_impl = src_layout->axes[j];
      if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
        if (orig_axis.IsPrimal()) {
          Expr orig_var = orig_axis_impl->var;
          const int32_t factor = src_layout.FactorOf(orig_axis);
          if (factor > 0) {
            orig_var = orig_var * Expr(factor);
          }
          store = store + orig_var;
        } else {
          store = store + orig_axis_impl->var;
        }
      }
    }
    if (is_zero(store)) {
      // Not convertible
      return false;
    }

    if (store_axis.IsPrimal()) {
      const int32_t factor = dst_layout.FactorOf(store_axis);
      if (factor > 0) {
        store = store / Expr(factor);
      }
    } else {
      store = store % store_axis_impl->dom->extent;
    }

    rule->push_back(store);
  }
  return true;
}
/*!
 * \brief Apply a transform rule to concrete indices.
 * \param src_index The index expression of each source axis.
 * \param src_axis The source axes; src_axis[i]'s variable is bound to src_index[i].
 * \param transform_rule The rule expressions, written in terms of the source axis variables.
 * \return Each rule with the variables substituted, then simplified.
 */
inline Array<Expr> TransformIndex(const Array<Expr>& src_index,
                                  const Array<IterVar>& src_axis,
                                  const Array<Expr>& transform_rule) {
  // bind every source axis variable to the index supplied for it
  std::unordered_map<const Variable*, Expr> bindings;
  const size_t num_index = src_index.size();
  for (size_t k = 0; k < num_index; ++k) {
    const Variable* key = src_axis[k]->var.get();
    bindings[key] = src_index[k];
  }

  Array<Expr> transformed;
  for (Expr rule : transform_rule) {
    Expr bound = ir::Substitute(rule, bindings);
    transformed.push_back(ir::Simplify(bound));
  }
  return transformed;
}
/*!
 * \brief Map indices in the source layout to indices in the destination layout.
 * \param src_index One index per source axis.
 * \return The indices in the destination layout.
 */
Array<Expr> BijectiveLayout::ForwardIndex(const Array<Expr>& src_index) const {
  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
  const BijectiveLayoutNode* node = operator->();
  const Array<IterVar>& src_axes = node->src_layout->axes;
  CHECK_EQ(src_index.size(), src_axes.size())
    << "Input mismatch with layout " << node->src_layout;
  return TransformIndex(src_index, src_axes, node->forward_rule);
}
/*!
 * \brief Map indices in the destination layout back to indices in the source layout.
 * \param dst_index One index per destination axis.
 * \return The indices in the source layout.
 */
Array<Expr> BijectiveLayout::BackwardIndex(const Array<Expr>& dst_index) const {
  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
  const BijectiveLayoutNode* node = operator->();
  const Array<IterVar>& dst_axes = node->dst_layout->axes;
  CHECK_EQ(dst_index.size(), dst_axes.size())
    << "Output mismatch with layout " << node->dst_layout;
  return TransformIndex(dst_index, dst_axes, node->backward_rule);
}
/*!
 * \brief Infer the shape in the target layout from a shape in the source layout.
 * \param src_shape The size of each source axis.
 * \param src_axis The source axes; must have as many entries as \p src_shape.
 * \param target_axis The target axes; must have as many entries as \p transform_rule.
 * \param transform_rule The rule mapping source axis variables to each target axis.
 * \return The size of each target axis: the axis extent for a subordinate
 *         axis, the evaluated and simplified rule for a primal axis.
 */
inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
                                  const Array<IterVar>& src_axis,
                                  const Array<IterVar>& target_axis,
                                  const Array<Expr>& transform_rule) {
  CHECK_EQ(src_shape.size(), src_axis.size());
  // bind variables for original axes
  // for major-axis, bind the corresponding size
  // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
  // e.g., (C * 16 + c) / 32
  std::unordered_map<const Variable*, Expr> bind_map;
  for (size_t i = 0; i < src_shape.size(); ++i) {
    Expr orig_shape = src_shape[i];
    IterVar orig_axis = src_axis[i];
    if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
      if (orig_shape.defined()) {
        const auto* orig_shape_const = orig_shape.as<IntImm>();
        const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
        // as<IntImm>() yields nullptr for anything that is not an IntImm
        // (e.g. a symbolic size); only compare when both sides are constants.
        if (orig_shape_const != nullptr && orig_axis_extent != nullptr) {
          CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
            << "Input shape mismatch at index " << i << ". Expected "
            << orig_axis->dom->extent << ", get " << orig_shape;
        }
      }
      bind_map[orig_axis->var.get()] = Expr(0);
    } else {
      bind_map[orig_axis->var.get()] = orig_shape;
    }
  }
  // infer the target shape,
  // for major-axis, use the forward/backward_rule directly,
  // for minor-axis, simply use the extent.
  Array<Expr> result;
  CHECK_EQ(transform_rule.size(), target_axis.size());
  for (size_t i = 0; i < transform_rule.size(); ++i) {
    Expr rule = transform_rule[i];
    IterVar axis = target_axis[i];
    if (!LayoutAxis::Get(axis).IsPrimal()) {
      result.push_back(axis->dom->extent);
    } else {
      result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
    }
  }
  return result;
}
/*!
 * \brief Given the shape in the source layout, infer the shape in the destination layout.
 * \param shape The size of each source axis.
 * \return The size of each destination axis.
 */
Array<Expr> BijectiveLayout::ForwardShape(const Array<Expr>& shape) const {
  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
  const BijectiveLayoutNode* node = operator->();
  const Array<IterVar>& from_axes = node->src_layout->axes;
  const Array<IterVar>& to_axes = node->dst_layout->axes;
  return TransformShape(shape, from_axes, to_axes, node->forward_rule);
}
/*!
 * \brief Given the shape in the destination layout, recover the shape in the source layout.
 * \param shape The size of each destination axis.
 * \return The size of each source axis.
 */
Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const {
  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
  const BijectiveLayoutNode* node = operator->();
  const Array<IterVar>& from_axes = node->dst_layout->axes;
  const Array<IterVar>& to_axes = node->src_layout->axes;
  return TransformShape(shape, from_axes, to_axes, node->backward_rule);
}
/*!
 * \brief Build the bijective mapping between two layouts.
 * \param src_layout The source layout.
 * \param dst_layout The destination layout.
 * \return The mapping holding both rules, or an undefined BijectiveLayout
 *         if no forward rule can be built. Once the forward rule exists,
 *         the backward rule is required to exist as well.
 */
BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
                                          const Layout& dst_layout) {
  auto node = make_node<BijectiveLayoutNode>();
  node->src_layout = src_layout;
  node->dst_layout = dst_layout;

  const bool convertible =
      GetStoreRule(&node->forward_rule, node->src_layout, node->dst_layout);
  if (!convertible) {
    // not convertible
    return BijectiveLayout();
  }

  CHECK(GetStoreRule(&node->backward_rule, node->dst_layout, node->src_layout));
  return BijectiveLayout(node);
}
} // namespace tvm
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
* \brief Property def of nn operators. * \brief Property def of nn operators.
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/debug.h> #include <tvm/relay/attrs/debug.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <vector> #include <vector>
#include "./type_relations.h" #include "./type_relations.h"
#include "./op_common.h" #include "./op_common.h"
#include "./layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
* \file resize.cc * \file resize.cc
* \brief Image operators * \brief Image operators
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/image/resize.h> #include <topi/image/resize.h>
#include "../layout.h"
#include "../op_common.h" #include "../op_common.h"
namespace tvm { namespace tvm {
...@@ -28,17 +28,18 @@ bool ResizeRel(const Array<Type>& types, ...@@ -28,17 +28,18 @@ bool ResizeRel(const Array<Type>& types,
const ResizeAttrs* param = attrs.as<ResizeAttrs>(); const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->layout); const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW)) auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "Resize only support input layouts that are convertible from NCHW." << "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW); auto oshape = layout_converter.ForwardShape(data->shape);
oshape[2] = param->size[0]; oshape.Set(2, param->size[0]);
oshape[3] = param->size[1]; oshape.Set(3, param->size[1]);
// assign output type // assign output type
reporter->Assign(types[1], reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout), TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype)); data->dtype));
return true; return true;
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/relay/op/layout.cc
* \brief Layout expression.
*/
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(LayoutNode);
/*!
 * \brief Convert a shape expressed in src_layout into the matching shape in dst_layout.
 *
 * The shape is returned untouched when the two layout handles compare equal.
 * Otherwise both layouts must be defined and convertible; any violation aborts
 * through LOG(FATAL) / CHECK.
 *
 * \param src shape in src_layout; must have exactly src_layout.ndim() entries.
 * \param src_layout layout that \p src is expressed in.
 * \param dst_layout layout to express the result in.
 * \return shape with dst_layout.ndim() entries.
 */
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout) {
// One shape entry per (super- or sub-) dimension of the source layout.
CHECK_EQ(src_layout.ndim(), src.size());
// NOTE(review): Layout declares no operator== of its own in layout.h, so this
// presumably falls back to the base NodeRef comparison (likely handle identity
// rather than the name comparison done by Layout::Equals) -- confirm.
if (src_layout == dst_layout) {
return src;
} else if (!src_layout.defined()) {
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
} else if (!dst_layout.defined()) {
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
}
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from "
<< src_layout << " to " << dst_layout;
std::vector<IndexExpr> dst(dst_layout.ndim());
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_dim = src_layout[i];
// Only super-dimensions drive the conversion; a sub-dimension is handled
// together with its super-dimension below.
if (Layout::IsSuperdim(src_dim)) {
// Positions are -1 when the dimension is absent from the layout.
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
int src_factor = src_layout.Subsizeof(src_dim);
int dst_factor = dst_layout.Subsizeof(src_dim);
IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) {
// The source is split: the sub-dimension extent must equal the factor
// named in the layout, and the full extent is major * factor.
CHECK(is_const_int(src[src_minor_pos], src_factor))
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
src_dim_size *= src_factor;
}
dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
// The destination is split: major gets extent / factor, minor gets factor.
CHECK_GT(dst_factor, 0);
// Only a constant extent can be validated here, and only for being at
// least as large as the factor (divisibility is not checked).
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
CHECK_LE(dst_factor, const_src_dim_size[0])
<< "Converting " << Array<IndexExpr>(src)
<< " from " << src_layout
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
}
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
}
}
}
return dst;
}
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout) {
std::vector<IndexExpr> ret(src.size());
for (size_t i = 0; i < src.size(); ++i) {
ret[i] = src[i];
}
return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file relay/op/layout.h
* \brief Layout expression.
*
* This file is adapted from its nnvm counterpart and will keep evolving
* toward the new layout system.
*
* The layout is composed of upper cases, lower cases and numbers,
* where upper case indicates a (super-)dimension and
* the corresponding lower case with factor size indicates the split (sub-)dimension.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
*/
#ifndef TVM_RELAY_OP_LAYOUT_H_
#define TVM_RELAY_OP_LAYOUT_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/relay/base.h>
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
namespace tvm {
namespace relay {
/*!
 * \brief Node container holding the parsed form of a layout string.
 *
 * The per-letter arrays are filled by the Layout string constructor with one
 * entry per letter of the alphabet (index = letter - 'A' or letter - 'a').
 */
class LayoutNode : public Node {
public:
/*! \brief The layout string exactly as given, or "__undef__". */
std::string name;
/*! \brief Position of each upper-case dimension in the layout, -1 if absent. */
Array<Integer> superdim_pos;
/*! \brief Position of each lower-case dimension in the layout, -1 if absent. */
Array<Integer> subdim_pos;
/*! \brief Factor attached to each lower-case dimension, -1 if absent. */
Array<Integer> subdim_size;
/*! \brief Dimension characters in layout order, with the numeric factors dropped. */
Array<Integer> layout_simplified;
/*! \brief Expose every field to the attribute visitor under its own name. */
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("superdim_pos", &superdim_pos);
v->Visit("subdim_pos", &subdim_pos);
v->Visit("subdim_size", &subdim_size);
v->Visit("layout_simplified", &layout_simplified);
}
static constexpr const char* _type_key = "Layout";
TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
};
/*!
 * \brief Handle to a LayoutNode, built by parsing a layout string.
 *
 * Upper-case letters name (super-)dimensions; a lower-case letter preceded by
 * a number names the split (sub-)dimension of the matching upper-case letter,
 * e.g. NCHW16c. The string "__undef__" denotes the undefined layout.
 */
class Layout : public NodeRef {
 public:
  using LayoutDim = char;
  /*! \brief Number of distinct dimension letters (A-Z / a-z). */
  static constexpr uint32_t kUniqueDim = 26;

  /*! \brief wrap an existing node */
  explicit Layout(NodePtr<Node> n) : NodeRef(n) {}

  /*! \brief default constructor */
  Layout() : Layout("__undef__") {} // NOLINT(*)

  /*! \brief construct from a string */
  Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)

  /*!
   * \brief construct from a string.
   * \param layout input in layout convention:
   *        upper case indicates a dimension and
   *        the corresponding lower case with factor size
   *        indicates the split dimension.
   *        return undefined layout if "__undef__" is passed.
   */
  Layout(const std::string& name) { // NOLINT(*)
    node_ = make_node<LayoutNode>();

    // Per-letter tables; -1 marks a dimension that does not occur.
    std::vector<int32_t> superdim_pos(kUniqueDim, -1);
    std::vector<int32_t> subdim_pos(kUniqueDim, -1);
    std::vector<int32_t> subdim_size(kUniqueDim, -1);
    std::vector<char> layout_simplified;

    if (name != "__undef__") {  // parse layout string
      int32_t factor = 0;  // digits accumulated since the last letter
      uint32_t curr = 0;   // position of the next dimension
      for (size_t i = 0; i < name.size(); ++i) {
        const LayoutDim c = name.at(i);
        if (IsSuperdim(c)) {
          int pos = c - 'A';
          CHECK_EQ(factor, 0) << "Invalid layout " << name
                              << ": invalid factor size " << factor
                              << " before dimension " << c;
          CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << name
                                          << ": duplicate dimension " << c;
          superdim_pos[pos] = curr++;
          layout_simplified.push_back(c);
        } else if (IsSubdim(c)) {
          int pos = c - 'a';
          CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
                              << factor << " for dimension " << c;
          CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << name
                                        << ": duplicate dimension " << c;
          CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << name
                                         << ": duplicate dimension " << c;
          subdim_pos[pos] = curr++;
          subdim_size[pos] = factor;
          layout_simplified.push_back(c);
          factor = 0;
        } else if (c >= '0' && c <= '9') {
          CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
          factor = factor * 10 + c - '0';
        } else {
          LOG(FATAL) << "Invalid layout " << name;
        }
      }
      // Every sub-dimension needs its super-dimension somewhere in the layout.
      for (LayoutDim dim : layout_simplified) {
        CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
          << "Invalid layout " << name << ": missing axis "
          << static_cast<char>(dim - 'a' + 'A');
      }
    }

    LayoutNode *node = operator->();
    node->name = name;

    for (uint32_t i = 0; i < kUniqueDim; ++i) {
      node->superdim_pos.push_back(superdim_pos[i]);
      node->subdim_pos.push_back(subdim_pos[i]);
      node->subdim_size.push_back(subdim_size[i]);
    }
    for (LayoutDim dim : layout_simplified) {
      node->layout_simplified.push_back(dim);
    }
  }

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  const LayoutNode* operator->() const {
    return static_cast<const LayoutNode*>(node_.get());
  }

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  LayoutNode* operator->() {
    return static_cast<LayoutNode*>(node_.get());
  }

  /*!
   * \brief Check whether a given dimension is a super-dimension.
   * \param dim input dimension
   * \return Whether a given dimension is a super-dimension.
   */
  static bool IsSuperdim(LayoutDim dim) {
    return dim >= 'A' && dim <= 'Z';
  }

  /*!
   * \brief Check whether a given dimension is a sub-dimension.
   * \param dim input dimension
   * \return Whether a given dimension is a sub-dimension.
   */
  static bool IsSubdim(LayoutDim dim) {
    return dim >= 'a' && dim <= 'z';
  }

  /*!
   * \brief Convert a given dimension to super-dimension.
   * \param dim input dimension
   * \return The converted description.
   */
  static LayoutDim ToSuperdim(LayoutDim dim) {
    if (IsSubdim(dim)) {
      return dim - 'a' + 'A';
    }
    return dim;
  }

  /*!
   * \brief Convert a given dimension to sub-dimension.
   * \param dim input dimension
   * \return The converted description.
   */
  static LayoutDim ToSubdim(LayoutDim dim) {
    if (IsSuperdim(dim)) {
      return dim - 'A' + 'a';
    }
    return dim;
  }

  /*!
   * \brief Return an undefined layout.
   * \return a (global) undefined layout.
   */
  static const Layout& Undef() {
    static Layout undef;
    return undef;
  }

  /*!
   * \brief Two layouts are convertible only if
   *        they have same set of super-dimensions.
   *        e.g., NCHW, NCHW16c, NHWC are convertible between each other,
   *        but NCHW, CHW, OIHW are not.
   * \param dst the target layout
   * \return Whether can be converted to dst layout.
   */
  bool Convertible(const Layout &dst) const {
    const LayoutNode *n = operator->();
    if (!this->defined() || !dst.defined()) return false;
    for (size_t i = 0; i < kUniqueDim; ++i) {
      if ((n->superdim_pos[i]->value >= 0 && dst->superdim_pos[i]->value < 0) ||
          (n->superdim_pos[i]->value < 0 && dst->superdim_pos[i]->value >= 0)) {
        return false;
      }
    }
    return true;
  }

  /*!
   * \brief Returns a sublayout which is the portion of the object
   *        that starts at dimension \p pos and spans \p len dimensions
   *        (or until the end of the layout, whichever comes first).
   * \param pos The start position.
   * \param len The length of the sub-layout.
   * \return A newly constructed Layout object.
   *         The undefined layout if \p pos is greater than ndim().
   */
  Layout Sublayout(size_t pos, size_t len) const {
    const Array<Integer>& layout_simplified = operator->()->layout_simplified;
    if (pos > ndim()) return Layout::Undef();
    // Clamp against the remaining length instead of testing pos + len, which
    // wraps around for a very large len and would skip the clamp.
    len = std::min(len, ndim() - pos);
    std::ostringstream new_layout;
    for (size_t i = pos; i < pos + len; ++i) {
      if (IsSubdim(layout_simplified[i]->value)) {
        auto block_size = this->Subsizeof(layout_simplified[i]->value);
        CHECK_GT(block_size, 0);
        new_layout << block_size;
      }
      new_layout << static_cast<char>(layout_simplified[i]->value);
    }
    return Layout(new_layout.str());
  }

  /*! \return A newly constructed reversed Layout object. */
  Layout Reverse() const {
    const Array<Integer>& layout_simplified = operator->()->layout_simplified;
    if (!this->defined()) return Layout::Undef();
    std::ostringstream new_layout;
    for (int64_t i = this->ndim() - 1; i >= 0; --i) {
      if (IsSubdim(layout_simplified[i]->value)) {
        auto block_size = this->Subsizeof(layout_simplified[i]->value);
        CHECK_GT(block_size, 0);
        new_layout << block_size;
      }
      // The stored value is an integer; cast so the axis letter is written
      // rather than its numeric character code.
      new_layout << static_cast<char>(layout_simplified[i]->value);
    }
    return Layout(new_layout.str());
  }

  /*!
   * \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
   * \param dim The source dimension to be split. It must be a super-dimension.
   * \param target_pos The target position of the newly split sub-dimension.
   * \param size size of the sub-dimension.
   * \return A newly constructed Layout object.
   */
  Layout Split(LayoutDim dim, size_t target_pos, uint32_t size) const {
    const std::string &name = operator->()->name;
    CHECK(target_pos <= this->ndim()) << "Invalid split position "
                                      << target_pos << " for layout " << name;
    CHECK(IsSuperdim(dim)) << "Cannot split a sub-dimension " << dim;
    CHECK(this->Contains(dim)) << "Axis " << dim << " does not exist in " << name;
    CHECK(!this->Contains(ToSubdim(dim))) << "Dimension " << dim
                                           << " has already been split in "
                                           << name;
    CHECK(size > 0) << "Invalid split size " << size;
    std::ostringstream new_layout;
    // Iterate one past the end so the sub-dimension can also be appended last.
    for (size_t i = 0; i <= this->ndim(); ++i) {
      if (i == target_pos) {
        new_layout << size << Layout::ToSubdim(dim);
      }
      if (i == this->ndim()) break;
      new_layout << this->at(i);
    }
    Layout x(new_layout.str());
    return x;
  }

  /*! \return number of dimensions */
  size_t ndim() const {
    return operator->()->layout_simplified.size();
  }

  /*! \return number of super dimensions */
  size_t ndim_super() const {
    size_t ct = 0;
    for (auto x : operator->()->layout_simplified) {
      if (IsSuperdim(x))
        ct++;
    }
    return ct;
  }

  /*!
   * \brief The description of the \p i-th dimension.
   *        If it is a sub-dimension, the size will be returned as well,
   *        e.g., 16c. Otherwise a single character is returned, e.g., C.
   * \param i The position
   * \return the description of the dimension.
   */
  std::string at(size_t i) const {
    const Array<Integer>& layout_simplified = operator->()->layout_simplified;
    CHECK_LT(i, this->ndim()) << "position " << i
                              << " exceeds ndim=" << this->ndim();
    std::ostringstream repr;
    if (IsSubdim(layout_simplified[i]->value)) {
      auto factor = Subsizeof(layout_simplified[i]->value);
      CHECK_GT(factor, 0);
      repr << factor;
    }
    repr << static_cast<char>(layout_simplified[i]->value);
    return repr.str();
  }

  /*!
   * \brief return the index of the input dimension.
   *        If it is not found in the layout or the layout is undefined,
   *        return -1.
   * \param dim the input dimension.
   * \return the index or -1 if not found.
   */
  int32_t Indexof(LayoutDim dim) const {
    if (!this->defined()) return -1;
    else if (IsSuperdim(dim)) return operator->()->superdim_pos[dim - 'A']->value;
    else if (IsSubdim(dim)) return operator->()->subdim_pos[dim - 'a']->value;
    return -1;
  }

  /*!
   * \param dim the input super-dimension or sub-dimension.
   * \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
   *         or the size of \p dim itself (if \p dim is a sub-dimension).
   *         Return -1 if \p dim is not in the layout or the layout is undefined.
   */
  int64_t Subsizeof(LayoutDim dim) const {
    CHECK(IsSuperdim(dim) || IsSubdim(dim)) << "Invalid dim " << dim;
    if (!this->defined() || !this->Contains(ToSubdim(dim))) {
      return -1;
    }
    int idx = ToSubdim(dim) - 'a';
    return operator->()->subdim_size[idx]->value;
  }

  /*!
   * \brief Whether the layout contains a dimension.
   * \param dim dimension to be checked.
   * \return Whether the layout contains the dimension.
   */
  bool Contains(LayoutDim dim) const {
    if (IsSuperdim(dim)) {
      return operator->()->superdim_pos[dim-'A']->value >= 0;
    } else if (IsSubdim(dim)) {
      return operator->()->subdim_pos[dim-'a']->value >= 0;
    }
    return false;
  }

  /*!
   * \brief The dimension character at position \p i, without its factor.
   * \param i The position; not range-checked here.
   * \return the dimension character.
   */
  LayoutDim operator[](size_t i) const {
    return operator->()->layout_simplified[i];
  }

  /*! \return whether the layout is defined */
  bool defined() const {
    return operator->()->name != "__undef__";
  }

  /*! \return the string description of the layout */
  const std::string& name() const {
    return operator->()->name;
  }

  /*!
   * \brief Whether the two layouts are equal.
   * \param rhs Another layout.
   * \return whether the two layouts are equal.
   */
  bool Equals(const Layout &rhs) const {
    return operator->()->name == rhs->name;
  }

  /*!
   * \brief allow output string of layout to ostream
   * \param os the output stream
   * \param l the layout
   * \return the ostream
   */
  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
    os << l.name();
    return os;
  }

  using ContainerType = LayoutNode;
};
/*!
* \brief Convert shape in src_layout to shape in dst_layout
*
* The shape must have one entry per dimension of src_layout. Conversion
* aborts (CHECK / LOG(FATAL)) when either layout is undefined or when the
* two layouts are not convertible.
*
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout);
/*!
* \brief Convert shape in src_layout to shape in dst_layout
*
* Array-based overload: the entries of \p src are copied into a std::vector
* and passed on to the vector-based overload.
*
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_LAYOUT_H_
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
* \file convolution.cc * \file convolution.cc
* \brief Convolution operators * \brief Convolution operators
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -31,32 +31,36 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -31,32 +31,36 @@ bool Conv2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->data_layout); const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout); const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW." << "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
CHECK(kernel_layout.Convertible(kOIHW))
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW." << "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout; << " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW)) const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW." << "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout; << " But got " << out_layout;
std::vector<IndexExpr> dshape_nchw = ConvertLayout( Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
data->shape, in_layout, kNCHW);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x; IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined // infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) { if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2); CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape( Array<IndexExpr> wshape(
{param->channels, {param->channels,
dshape_nchw[1] / param->groups, dshape_nchw[1] / param->groups,
param->kernel_size[0], param->kernel_size[0],
param->kernel_size[1]}); param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout); wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels; channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
...@@ -65,7 +69,7 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -65,7 +69,7 @@ bool Conv2DRel(const Array<Type>& types,
} else { } else {
// use weight to infer the conv shape. // use weight to infer the conv shape.
if (weight == nullptr) return false; if (weight == nullptr) return false;
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW); auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) { if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->kernel_size.size(), 2);
// check the size // check the size
...@@ -73,13 +77,13 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -73,13 +77,13 @@ bool Conv2DRel(const Array<Type>& types,
reporter->AssertEQ(param->kernel_size[1], wshape[3])) reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, " << "Conv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape); << " wshape=" << wshape;
} }
if (param->channels.defined()) { if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0])) CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "Conv2D: shape of weight is inconsistent with channels, " << "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape); << " wshape=" << wshape;
} }
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1])); CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
channels = wshape[0]; channels = wshape[0];
...@@ -87,15 +91,15 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -87,15 +91,15 @@ bool Conv2DRel(const Array<Type>& types,
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
} }
// dilation // dilation
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1; oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1; oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
DataType out_dtype = param->out_dtype; DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) { if (out_dtype.bits() == 0) {
out_dtype = data->dtype; out_dtype = data->dtype;
} }
oshape = ConvertLayout(oshape, kNCHW, out_layout); oshape = trans_out_layout.BackwardShape(oshape);
// assign output type // assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true; return true;
...@@ -193,33 +197,38 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -193,33 +197,38 @@ bool Conv2DTransposeRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->data_layout); const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout); const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW." << "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
CHECK(kernel_layout.Convertible(kOIHW))
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW." << "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout; << " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW)) const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW." << "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout; << " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x; IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW); auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
// infer weight if the kernel_size and channels are defined // infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) { if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2); CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape({dshape_nchw[1], Array<IndexExpr> wshape({dshape_nchw[1],
param->channels / param->groups, param->channels / param->groups,
param->kernel_size[0], param->kernel_size[0],
param->kernel_size[1]}); param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout); wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
channels = param->channels; channels = param->channels;
...@@ -229,7 +238,7 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -229,7 +238,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
} else { } else {
// use weight to infer the conv shape. // use weight to infer the conv shape.
if (weight == nullptr) return false; if (weight == nullptr) return false;
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW); auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) { if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->kernel_size.size(), 2);
// check the size // check the size
...@@ -251,17 +260,17 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -251,17 +260,17 @@ bool Conv2DTransposeRel(const Array<Type>& types,
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
} }
// dilation // dilation
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
2 * param->padding[0] + param->output_padding[0]); 2 * param->padding[0] + param->output_padding[0]));
oshape[3] = (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1]); 2 * param->padding[1] + param->output_padding[1]));
DataType out_dtype = param->out_dtype; DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) { if (out_dtype.bits() == 0) {
out_dtype = data->dtype; out_dtype = data->dtype;
} }
oshape = ConvertLayout(oshape, kNCHW, out_layout); oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true; return true;
} }
...@@ -349,20 +358,24 @@ bool Conv2DWinogradRel(const Array<Type>& types, ...@@ -349,20 +358,24 @@ bool Conv2DWinogradRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->data_layout); const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout); const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW." << "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
CHECK(kernel_layout.Convertible(kOIHW))
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW." << "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout; << " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW)) const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW." << "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout; << " But got " << out_layout;
std::vector<IndexExpr> dshape_nchw = ConvertLayout( Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
data->shape, in_layout, kNCHW);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x; IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
...@@ -384,15 +397,15 @@ bool Conv2DWinogradRel(const Array<Type>& types, ...@@ -384,15 +397,15 @@ bool Conv2DWinogradRel(const Array<Type>& types,
// can handle this correctly in alter_op_layout. // can handle this correctly in alter_op_layout.
// dilation // dilation
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1; oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1; oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
DataType out_dtype = param->out_dtype; DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) { if (out_dtype.bits() == 0) {
out_dtype = data->dtype; out_dtype = data->dtype;
} }
oshape = ConvertLayout(oshape, kNCHW, out_layout); oshape = trans_out_layout.BackwardShape(oshape);
// assign output type // assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true; return true;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief Property def of nn operators. * \brief Property def of nn operators.
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
...@@ -14,7 +15,6 @@ ...@@ -14,7 +15,6 @@
#include "../type_relations.h" #include "../type_relations.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
#include "../op_common.h" #include "../op_common.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
* \file pad.cc * \file pad.cc
* \brief Implementation of operator pad * \brief Implementation of operator pad
*/ */
#include <tvm/data_layout.h>
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn.h> #include <topi/nn.h>
#include <vector> #include <vector>
#include "../layout.h"
#include "../op_common.h" #include "../op_common.h"
namespace tvm { namespace tvm {
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
* \file pooling.cc * \file pooling.cc
* \brief Pooling operators * \brief Pooling operators
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h> #include <topi/nn/pooling.h>
#include <vector> #include <vector>
#include "../layout.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
namespace tvm { namespace tvm {
...@@ -32,14 +32,15 @@ Array<Array<Layout> > Pool2DInferCorrectLayout( ...@@ -32,14 +32,15 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
Layout raw_layout(params->layout); Layout raw_layout(params->layout);
Layout input = new_in_layouts[0]; Layout input = new_in_layouts[0];
if (input.Indexof('W') == raw_layout.Indexof('W') && if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.Indexof('H') == raw_layout.Indexof('H') && input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains('w') && !input.Contains('h')) { !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout params->layout = input.name(); // modify self to follow the input layout
} }
} }
return Array<Array<Layout> >{{params->layout}, {params->layout}}; Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
} }
template <typename AttrType> template <typename AttrType>
...@@ -59,13 +60,13 @@ bool Pool2DRel(const Array<Type>& types, ...@@ -59,13 +60,13 @@ bool Pool2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') && CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains('h') && !layout.Contains('w')) !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout << "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split"; << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H'); const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.Indexof('W'); const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
IndexExpr pad_h, pad_w; IndexExpr pad_h, pad_w;
if (param->padding.size() == 1) { if (param->padding.size() == 1) {
...@@ -125,6 +126,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -125,6 +126,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>(); const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr); CHECK(param != nullptr);
auto pool_size = param->pool_size; auto pool_size = param->pool_size;
...@@ -132,10 +134,13 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -132,10 +134,13 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
auto padding = param->padding; auto padding = param->padding;
auto ceil_mode = param->ceil_mode; auto ceil_mode = param->ceil_mode;
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW")))
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "max_pool2d currently only supports layouts that are convertible from NCHW"; << "max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1) << "max_pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
CHECK_EQ(layout.Indexof('w'), -1) << "max_pool2d does not support input split on width"; << "max_pool2d does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)" << "Pool2D only support 4-D input (e.g., NCHW)"
...@@ -271,13 +276,13 @@ bool GlobalPool2DRel(const Array<Type>& types, ...@@ -271,13 +276,13 @@ bool GlobalPool2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') && CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains('h') && !layout.Contains('w')) !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout << "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split"; << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H'); const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.Indexof('W'); const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
Array<IndexExpr> oshape(dshape); Array<IndexExpr> oshape(dshape);
oshape.Set(hidx, 1); oshape.Set(hidx, 1);
oshape.Set(widx, 1); oshape.Set(widx, 1);
...@@ -293,14 +298,15 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs, ...@@ -293,14 +298,15 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<GlobalPool2DAttrs>(); const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW"))) CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1) CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "global_avg_pool2d does not support input split on height"; << "global_avg_pool2d does not support input split on height";
CHECK_EQ(layout.Indexof('w'), -1) CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "global_avg_pool2d does not support input split on width"; << "global_avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file upsampling.cc * \file upsampling.cc
* \brief upsampling operator * \brief upsampling operator
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
...@@ -11,7 +12,6 @@ ...@@ -11,7 +12,6 @@
#include <topi/nn/upsampling.h> #include <topi/nn/upsampling.h>
#include <vector> #include <vector>
#include "../op_common.h" #include "../op_common.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -31,18 +31,20 @@ bool UpSamplingRel(const Array<Type>& types, ...@@ -31,18 +31,20 @@ bool UpSamplingRel(const Array<Type>& types,
const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>(); const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->layout); const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW))
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "UpSampling only support input layouts that are convertible from NCHW." << "UpSampling only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW); auto oshape = layout_converter.ForwardShape(data->shape);
oshape[2] = oshape[2] * param->scale; oshape.Set(2, oshape[2] * param->scale);
oshape[3] = oshape[3] * param->scale; oshape.Set(3, oshape[3] * param->scale);
// assign output type // assign output type
reporter->Assign(types[1], reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout), TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype)); data->dtype));
return true; return true;
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/data_layout.h>
#include <topi/transform.h> #include <topi/transform.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/broadcast.h> #include <topi/broadcast.h>
...@@ -16,7 +17,6 @@ ...@@ -16,7 +17,6 @@
#include "../op_common.h" #include "../op_common.h"
#include "../../../arithmetic/compute_expr.h" #include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -218,7 +218,7 @@ Array<Array<Layout>> ConcatenateLayout( ...@@ -218,7 +218,7 @@ Array<Array<Layout>> ConcatenateLayout(
Layout ret; Layout ret;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated. if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
Layout::LayoutDim concate_dim = old_in_layouts[0][axis]; const auto& concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) { for (size_t i = 0; i < new_in_layouts.size(); ++i) {
if (new_in_layouts[i].ndim() > axis && if (new_in_layouts[i].ndim() > axis &&
new_in_layouts[i][axis] == concate_dim) { new_in_layouts[i][axis] == concate_dim) {
...@@ -234,7 +234,7 @@ Array<Array<Layout>> ConcatenateLayout( ...@@ -234,7 +234,7 @@ Array<Array<Layout>> ConcatenateLayout(
} }
} }
if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) { if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}}; return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
} }
} }
...@@ -1682,46 +1682,10 @@ Array<Tensor> LayoutTransformCompute(const Attrs& attrs, ...@@ -1682,46 +1682,10 @@ Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
const LayoutTransformAttrs *param = attrs.as<LayoutTransformAttrs>(); const auto* param = attrs.as<LayoutTransformAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
return Array<Tensor>{
Layout src_layout(param->src_layout); topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
Layout dst_layout(param->dst_layout);
if (src_layout.Equals(dst_layout)) {
return Array<Tensor>{ inputs[0] };
}
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from " << param->src_layout << " to " << param->dst_layout;
const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout);
return Array<Tensor> {
topi::layout_transform(inputs[0], out_shape, [&](const Array<tvm::Var>& dst_indices) {
std::vector<tvm::Expr> dst_to_src_indices;
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_axis = src_layout[i];
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.Subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.Subsizeof(src_axis));
tvm::Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::IsSuperdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::IsSubdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<tvm::Expr>(dst_to_src_indices);
})
}; };
} }
...@@ -1738,10 +1702,12 @@ bool LayoutTransformRel(const Array<Type>& types, ...@@ -1738,10 +1702,12 @@ bool LayoutTransformRel(const Array<Type>& types,
CHECK(src_layout.defined() && dst_layout.defined()) CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout"; << "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout);
CHECK(layout_converter.defined())
<< "cannot convert from " << params->src_layout << " to " << params->dst_layout; << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout); const auto& out_shape = layout_converter.ForwardShape(data->shape);
reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype)); reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
return true; return true;
} }
......
...@@ -26,7 +26,7 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { ...@@ -26,7 +26,7 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
if (src_layout.Equals(dst_layout)) { return raw; } if (src_layout.Equals(dst_layout)) { return raw; }
CHECK(src_layout.defined() && dst_layout.defined()) CHECK(src_layout.defined() && dst_layout.defined())
<< "Cannot insert layout transform because there are undefined layouts"; << "Cannot insert layout transform because there are undefined layouts";
CHECK(src_layout.Convertible(dst_layout)) CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
<< "Cannot insert layout transform because there are inconvertible layouts: " << "Cannot insert layout transform because there are inconvertible layouts: "
<< src_layout << " v.s. " << dst_layout; << src_layout << " v.s. " << dst_layout;
static auto &transform_op = Op::Get("layout_transform"); static auto &transform_op = Op::Get("layout_transform");
......
...@@ -9,10 +9,9 @@ ...@@ -9,10 +9,9 @@
#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ #ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ #define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#include <tvm/data_layout.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include "../op/layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -78,9 +77,9 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs, ...@@ -78,9 +77,9 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx, layouts.Set(undef_idx,
layouts[defined_idx].Sublayout( layouts[defined_idx].SubLayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size())); old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}}; return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
} else { } else {
// only know the tensor with smaller dimensions, // only know the tensor with smaller dimensions,
...@@ -90,21 +89,22 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs, ...@@ -90,21 +89,22 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
} }
} else { } else {
// try to broadcast the tensors to the larger dimension // try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1; int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
int small_idx = 1 - large_idx; int small_idx = 1 - large_idx;
Layout ret = layouts[large_idx]; Layout ret = layouts[large_idx];
// extract common part // extract common part
size_t i = layouts[large_idx].ndim(); size_t i = layouts[large_idx].ndim();
for (; i != 0; --i) { for (; i != 0; --i) {
auto dim = layouts[large_idx][i-1]; const auto& axis = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) { if (!layouts[small_idx].Contains(axis.ToPrimal())) {
break; break;
} }
} }
Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i); Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
if (!layouts[small_idx].Convertible(common_part)) { // fail if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
// not convertible
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}}; return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
} }
......
...@@ -91,8 +91,10 @@ class BranchGroupFinder : private ExprVisitor { ...@@ -91,8 +91,10 @@ class BranchGroupFinder : private ExprVisitor {
CHECK(attrs_b); CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>(); const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>(); const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW); const auto shape_a = BijectiveLayoutNode::make(
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW); Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
const auto shape_b = BijectiveLayoutNode::make(
Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
......
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
* \brief Fold axis scaling into weights of * \brief Fold axis scaling into weights of
* conv/dense operators. * conv/dense operators.
*/ */
#include <tvm/data_layout.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "pattern_util.h" #include "pattern_util.h"
#include "pass_util.h" #include "pass_util.h"
#include "../op/layout.h"
namespace tvm { namespace tvm {
...@@ -435,8 +435,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) { ...@@ -435,8 +435,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C'); int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = data_layout.Indexof('c'); int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>(); Message none = NullValue<Message>();
...@@ -449,7 +449,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) { ...@@ -449,7 +449,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('i') < 0 && if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis}; data_axes = {c_big_axis};
...@@ -473,15 +473,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call, ...@@ -473,15 +473,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C'); int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(sdata->axes.size() == 1 && CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value); c_big_axis == sdata->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O'); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int big_ic_axis = kernel_layout.Indexof('I'); int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
...@@ -857,8 +857,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) ...@@ -857,8 +857,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
CHECK(param != nullptr); CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C'); int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = out_layout.Indexof('c'); int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
...@@ -869,8 +869,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) ...@@ -869,8 +869,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('o') < 0 && if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
kernel_layout.Indexof('i') < 0 && kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
return MessageNode::make({c_big_axis}, false); return MessageNode::make({c_big_axis}, false);
...@@ -891,16 +891,16 @@ Expr Conv2DBackwardTransform(const Call& call, ...@@ -891,16 +891,16 @@ Expr Conv2DBackwardTransform(const Call& call,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C'); int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('o'), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(message->axes.size() == 1 && CHECK(message->axes.size() == 1 &&
c_big_axis == message->axes[0]->value); c_big_axis == message->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O'); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d); CHECK(param->groups == 1 || is_depthwise_conv2d);
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "../op/layout.h" #include <tvm/data_layout.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -51,8 +51,8 @@ int64_t ConvMacCount(const Call& call_node) { ...@@ -51,8 +51,8 @@ int64_t ConvMacCount(const Call& call_node) {
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>(); const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape; Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_attr->data_layout; std::string data_layout = conv_2d_attr->data_layout;
int32_t C_ind = Layout(data_layout).Indexof('C'); int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).Indexof('c'); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK(C_ind != -1) CHECK(C_ind != -1)
<< "There is no input channel dimension."; << "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value); int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
......
...@@ -8,13 +8,13 @@ ...@@ -8,13 +8,13 @@
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_ #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_ #define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <string> #include <string>
#include "../op/layout.h"
namespace tvm { namespace tvm {
...@@ -155,9 +155,8 @@ inline bool IsDepthwiseConv2D(const Call& call, ...@@ -155,9 +155,8 @@ inline bool IsDepthwiseConv2D(const Call& call,
const Conv2DAttrs* param, const Conv2DAttrs* param,
const Layout& kernel_layout) { const Layout& kernel_layout) {
static const Layout kOIHW("OIHW"); static const Layout kOIHW("OIHW");
auto wshape = ConvertLayout( const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
call->args[1]->type_as<TensorTypeNode>()->shape, auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
kernel_layout, kOIHW);
return is_const_int(wshape[0], param->groups) && return is_const_int(wshape[0], param->groups) &&
is_const_int(wshape[1], 1); is_const_int(wshape[1], 1);
} }
......
"""Test layout and bijective-layout node"""
import tvm
from topi.util import get_const_tuple
def test_layout():
    """Exercise factor, index, membership and item lookups on the "NCHW16c" layout."""
    nchw16c = tvm.layout("NCHW16c")
    assert nchw16c is not None
    assert isinstance(nchw16c, tvm.tensor.Layout)

    # Expected factor_of() result for each queried axis name.
    for axis, factor in (("c", 16), ("C", 16), ("N", -1)):
        assert nchw16c.factor_of(axis) == factor

    # Expected index_of() result for each axis; "O" is not part of the layout.
    for axis, position in (("N", 0), ("C", 1), ("H", 2),
                           ("W", 3), ("c", 4), ("O", -1)):
        assert nchw16c.index_of(axis) == position

    # Membership checks.
    for axis in ("N", "C", "H", "W", "c"):
        assert axis in nchw16c
    assert "O" not in nchw16c

    # Item access, including a negative index for the last axis.
    for position, axis in ((0, "N"), (1, "C"), (2, "H"),
                           (3, "W"), (4, "c"), (-1, "c")):
        assert nchw16c[position] == axis
def test_bilayout_convertible():
    """bijective_layout gives None for an inconvertible pair, an object otherwise."""
    # "NCHW" and "ABCD" cannot be converted into each other.
    inconvertible = tvm.bijective_layout("NCHW", "ABCD")
    assert inconvertible is None

    # "NCHW" and "NCHW16c" can.
    convertible = tvm.bijective_layout("NCHW", "NCHW16c")
    assert convertible is not None
def test_bilayout_shape():
    """Round-trip a shape through forward_shape and backward_shape."""
    converter = tvm.bijective_layout("NCHW", "NCHW16c")
    assert isinstance(converter, tvm.tensor.BijectiveLayout)

    # NCHW -> NCHW16c
    packed_shape = converter.forward_shape((1, 32, 7, 7))
    assert get_const_tuple(packed_shape) == (1, 2, 7, 7, 16)

    # NCHW16c -> NCHW gives back the starting shape.
    restored_shape = converter.backward_shape(packed_shape)
    assert get_const_tuple(restored_shape) == (1, 32, 7, 7)
def test_bilayout_index():
    """Check forward_index and backward_index on a matching pair of indices."""
    converter = tvm.bijective_layout("NCHW", "NCHW16c")

    # NCHW index -> NCHW16c index
    packed_index = converter.forward_index([0, 18, 6, 6])
    assert get_const_tuple(packed_index) == (0, 1, 6, 6, 2)

    # NCHW16c index -> NCHW index
    restored_index = converter.backward_index([0, 1, 6, 6, 2])
    assert get_const_tuple(restored_index) == (0, 18, 6, 6)
if __name__ == "__main__":
    # Run every test in this file, in declaration order.
    for run_case in (test_layout,
                     test_bilayout_convertible,
                     test_bilayout_shape,
                     test_bilayout_index):
        run_case()
...@@ -450,28 +450,5 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I, ...@@ -450,28 +450,5 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
return tvm::compute(output_shape, l, name, tag); return tvm::compute(output_shape, l, name, tag);
} }
using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;
/*!
 * \brief Transform the layout according to the mapping function \p to_src_indices.
 * \param src the source input.
 * \param dst_shape the output shape.
 * \param to_src_indices the mapping function from output (destination) indices
 *        to the input (source) indices used to read \p src.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return A tensor with shape \p dst_shape.
 */
inline Tensor layout_transform(const Tensor& src,
                               const Array<Expr>& dst_shape,
                               const FLayoutIndicesTransform& to_src_indices,
                               const std::string name = "layout_transform",
                               const std::string tag = kInjective) {
  // Every output element is read from src at the indices produced by
  // to_src_indices for that output position.
  return compute(
    dst_shape, [&](const Array<Var>& dst_indices) {
      return src(to_src_indices(dst_indices));
    }, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_NN_H_ #endif // TOPI_NN_H_
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "topi/detail/ravel_unravel.h" #include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h" #include "topi/detail/constant_utils.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
#include "tvm/data_layout.h"
namespace topi { namespace topi {
using namespace tvm; using namespace tvm;
...@@ -882,5 +883,43 @@ inline Tensor arange(const Expr start, ...@@ -882,5 +883,43 @@ inline Tensor arange(const Expr start,
}, name, tag); }, name, tag);
} }
/*!
 * \brief Convert a tensor from \p src_layout to \p dst_layout.
 * \param src the source input.
 * \param src_layout the layout \p src is given in.
 * \param dst_layout the layout to convert to.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return \p src itself when both layouts are equal; otherwise a tensor whose
 *         shape is the shape of \p src converted to \p dst_layout.
 */
inline Tensor layout_transform(const Tensor& src,
                               const std::string& src_layout,
                               const std::string& dst_layout,
                               const std::string name = "layout_transform",
                               const std::string tag = kInjective) {
  Layout from_layout = LayoutNode::make(src_layout);
  Layout to_layout = LayoutNode::make(dst_layout);

  // Identical layouts need no transformation.
  if (from_layout.Equals(to_layout)) {
    return src;
  }

  CHECK(from_layout.defined() && to_layout.defined())
    << "cannot convert from/to undefined layout";

  auto converter = BijectiveLayoutNode::make(from_layout, to_layout);
  CHECK(converter.defined())
    << "cannot convert from " << src_layout << " to " << dst_layout;

  Array<Expr> out_shape = converter.ForwardShape(src->shape);

  // For each output position, map its indices back to the source layout
  // and read the corresponding element of src.
  auto read_from_src = [&](const Array<Var>& out_vars) {
    Array<Expr> out_indices(out_vars.begin(), out_vars.end());
    return src(converter.BackwardIndex(out_indices));
  };
  return compute(out_shape, read_from_src, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_TRANSFORM_H_ #endif // TOPI_TRANSFORM_H_
...@@ -318,3 +318,20 @@ def arange(start, stop=None, step=1, dtype="float32"): ...@@ -318,3 +318,20 @@ def arange(start, stop=None, step=1, dtype="float32"):
stop = start stop = start
start = 0 start = 0
return cpp.arange(start, stop, step, dtype) return cpp.arange(start, stop, step, dtype)
def layout_transform(array, src_layout, dst_layout):
    """Re-arrange a tensor from one data layout into another.

    This is a thin wrapper that forwards its arguments unchanged to
    ``cpp.layout_transform``.

    Parameters
    ----------
    array : tvm.Tensor
        The source array.

    src_layout : str
        The layout the source array is currently in.

    dst_layout : str
        The layout to convert the array to.

    Returns
    -------
    The value produced by ``cpp.layout_transform`` for the given arguments.
    """
    transformed = cpp.layout_transform(array, src_layout, dst_layout)
    return transformed
...@@ -272,6 +272,11 @@ TVM_REGISTER_GLOBAL("topi.split") ...@@ -272,6 +272,11 @@ TVM_REGISTER_GLOBAL("topi.split")
} }
}); });
// Registers layout_transform as the global packed function
// "topi.layout_transform". The three positional arguments are forwarded in
// order to layout_transform (name and tag keep their defaults) and the result
// is stored in the return value.
// NOTE(review): presumably args[0] is the source tensor and args[1]/args[2]
// are the source/destination layout strings -- confirm against the overload
// that is selected here.
TVM_REGISTER_GLOBAL("topi.layout_transform")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = layout_transform(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.take") TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) { if (args.size() == 2) {
......
...@@ -449,6 +449,34 @@ def test_arange(): ...@@ -449,6 +449,34 @@ def test_arange():
verify_arange(20, 1, -1.5) verify_arange(20, 1, -1.5)
def test_layout_transform():
    """Check topi.layout_transform from "NCHW" to "NCHW16c" against a NumPy reference."""
    in_shape = (1, 32, 8, 8)
    A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.layout_transform(A, "NCHW", "NCHW16c")

    # NumPy reference: move channels last, split them into (2, 16) blocks,
    # then bring the outer channel block back in front of the spatial axes.
    src_np = np.random.uniform(size=in_shape).astype(A.dtype)
    expected_np = (src_np.transpose(0, 2, 3, 1)
                   .reshape(1, 8, 8, 2, 16)
                   .transpose(0, 3, 1, 2, 4))

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        tvm_input = tvm.nd.array(src_np, ctx)
        tvm_output = tvm.nd.empty(expected_np.shape, ctx=ctx, dtype=B.dtype)
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        f = tvm.build(s, [A, B], device, name="layout_transform")
        f(tvm_input, tvm_output)
        tvm.testing.assert_allclose(tvm_output.asnumpy(), expected_np)

    for backend in get_all_backend():
        check_device(backend)
if __name__ == "__main__": if __name__ == "__main__":
test_strided_slice() test_strided_slice()
test_concatenate() test_concatenate()
...@@ -462,3 +490,4 @@ if __name__ == "__main__": ...@@ -462,3 +490,4 @@ if __name__ == "__main__":
test_take() test_take()
test_gather_nd() test_gather_nd()
test_arange() test_arange()
test_layout_transform()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment