Commit ee79703c by Yizhi Liu Committed by Lianmin Zheng

[Lang] Layout in TVM node system (#2509)

* move layout.h & layout.cc from relay to tvm

* change ConvertLayout in relay to bijectiveLayout->Forward/backward

* add first test case

* add LayoutAxis

* add LayoutAxis struct and compiles

* simplify BijectiveLayout rule construct

* polish func name for Layout, move impl to .cc, remove Layout::defined(), add defined() checker

* partially add layout py support

* add layout test cases

* add doc for tvm.layout & tvm.bijective_layout

* fix lint

* fix lint

* fix layout name generation bug

* fix layout typo

* address comments and add topi.layout_transform

* layout.h->data_layout.h, test_lang_layout.py->test_lang_data_layout.py
parent ddc31fd4
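A minimal usage sketch of the API introduced by this commit (added here for illustration, in the spirit of the test cases in test_lang_data_layout.py; the names come from the Python bindings added below):

import tvm

# A Layout node: NCHW with the channel axis split by 16 ("NCHW16c").
layout = tvm.layout("NCHW16c")
# A BijectiveLayout node mapping between a source and a destination layout.
bijective = tvm.bijective_layout("NCHW", "NCHW16c")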
......@@ -68,6 +68,7 @@ List of operators
topi.greater_equal
topi.less_equal
topi.arange
topi.layout_transform
topi.image.resize
......@@ -125,6 +126,7 @@ topi
.. autofunction:: topi.greater
.. autofunction:: topi.less
.. autofunction:: topi.arange
.. autofunction:: topi.layout_transform
topi.nn
~~~~~~~
......
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/data_layout.h
* \brief Layout expression to describe the data organization of a tensor.
* And BijectiveLayout to map between two data layouts.
*/
#ifndef TVM_DATA_LAYOUT_H_
#define TVM_DATA_LAYOUT_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
#include "ir_operator.h"
namespace tvm {
class LayoutAxis {
public:
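// Get the singleton LayoutAxis for the given character name (must be A-Z or a-z).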
static const LayoutAxis& Get(const char name);
// Get the singleton LayoutAxis using itvar->var->name_hint
static const LayoutAxis& Get(const IterVar& itvar);
// Get the singleton LayoutAxis using name[0] (size of name must be 1).
static const LayoutAxis& make(const std::string& name);
inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
inline std::string name() const { return std::string(1, name_); }
// if current axis is primal, switch the axis to its subordinate one,
// else switch to the primal.
inline const LayoutAxis& ToDual() const {
if (name_ >= 'A' && name_ <= 'Z') {
return LayoutAxis::Get(name_ - 'A' + 'a');
} else {
return LayoutAxis::Get(name_ - 'a' + 'A');
}
}
// return the primal axis. If it is already primal, return itself.
const LayoutAxis& ToPrimal() const {
return IsPrimal() ? *this : ToDual();
}
// return the subordinate axis. If it is already subordinate, return itself.
const LayoutAxis& ToSubordinate() const {
return IsPrimal() ? ToDual() : *this;
}
inline bool operator==(const LayoutAxis& rhs) const {
return name_ == rhs.name_;
}
friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
os << l.name();
return os;
}
private:
static const LayoutAxis UPPER_CASE[];
static const LayoutAxis LOWER_CASE[];
LayoutAxis(const LayoutAxis&);
LayoutAxis& operator=(const LayoutAxis&);
explicit LayoutAxis(const char name) : name_(name) {}
const char name_;
};
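// Example (editorial note): LayoutAxis::Get('C').ToSubordinate() is the axis 'c',
// LayoutAxis::Get('c').ToPrimal() is 'C', and ToDual() flips between the two.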
class Layout;
// Internal node container for Layout
class LayoutNode : public Node {
public:
/*! \brief string representation of layout */
std::string name;
/*! \brief specify each axis of the layout,
* in which the variable name is the name of the axis.
* The IterVar's extent indicates the size of the axis,
* it is a variable for a primal axis, but a constant for a subordinate axis.
*/
Array<IterVar> axes;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("axes", &axes);
}
TVM_DLL static Layout make(const std::string& layout);
static constexpr const char* _type_key = "Layout";
TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
};
/*!
* \brief Layout describes how data is organized within an N-dimensional tensor.
* It is composed of upper cases, lower cases and numbers,
* where upper case indicates a primal axis and
* the corresponding lower case with factor size indicates the subordinate axis.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
*/
class Layout : public NodeRef {
public:
explicit Layout(NodePtr<Node> n) : NodeRef(n) {}
/*! \brief default constructor */
Layout() = default;
explicit Layout(const Array<IterVar>& axes);
/*! \brief construct from a string */
Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
/*!
* \brief construct from a string.
* \param name input in layout convention:
* upper case indicates a dimension and
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
Layout(const std::string& name); // NOLINT(*)
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
const LayoutNode* operator->() const {
return static_cast<const LayoutNode*>(node_.get());
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
LayoutNode* operator->() {
return static_cast<LayoutNode*>(node_.get());
}
/*!
* \brief Return an undefined layout.
* \return a (global) undefined layout.
*/
static const Layout& Undef() {
static Layout undef;
return undef;
}
/*!
* \brief Returns a sub-layout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
* \return A newly constructed Layout object.
*/
Layout SubLayout(size_t pos, size_t len) const;
/*!
* \brief Split \p axis by \p size and put the sub-axis to position \p target_pos.
* \param axis The source axis to be split. It must be a primal-axis.
* \param target_pos The target position of the newly split subordinate-axis.
* \param factor The size of the newly split subordinate-axis.
* \return A newly constructed Layout object.
*/
Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const;
/*! \return number of dimensions */
inline size_t ndim() const {
if (!defined()) return 0;
return operator->()->axes.size();
}
/*! \return number of primal dimensions */
inline size_t ndim_primal() const {
if (!defined()) return 0;
size_t ct = 0;
for (auto x : operator->()->axes) {
if (LayoutAxis::Get(x).IsPrimal()) {
ct++;
}
}
return ct;
}
/*!
* \brief return the index of the input axis.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param axis the input axis.
* \return the index or -1 if not found.
*/
inline int32_t IndexOf(const LayoutAxis& axis) const {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
}
return -1;
}
/*!
* \brief Get the factor size of the subordinate axis.
* \param axis the input primal-axis or subordinate-axis.
* \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis),
* or the size of \p axis itself (if \p axis is a subordinate-axis).
* Return -1 if \p axis is not in the layout or the layout is undefined.
*/
int32_t FactorOf(const LayoutAxis& axis) const;
/*!
* \brief Whether the layout contains an axis.
* \param axis axis to be checked.
* \return Whether the layout contains the axis.
*/
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const IterVar var : operator->()->axes) {
if (var->var.get()->name_hint == axis.name()) {
return true;
}
}
return false;
}
const LayoutAxis& operator[](int32_t i) const {
CHECK(defined()) << "Try to access axis from an undefined layout.";
int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
const IterVar axis = operator->()->axes[index];
return LayoutAxis::Get(axis);
}
/*! \return the string description of the layout */
inline std::string name() const {
if (!defined()) return "__undef__";
return operator->()->name;
}
/*!
* \brief Whether the two layouts are equal.
* \param rhs Another layout.
* \return whether the two layouts are equal.
*/
inline bool Equals(const Layout &rhs) const {
return name() == rhs.name();
}
/*!
* \brief allow output string of layout to ostream
* \param os the output stream
* \param l the layout
* \return the ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
os << l.name();
return os;
}
using ContainerType = LayoutNode;
};
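// Example (editorial note): Layout("NCHW").Split(LayoutAxis::Get('C'), 4, 16)
// yields "NCHW16c", and Layout("NCHW16c").SubLayout(1, 3) yields "CHW".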
class BijectiveLayout;
// Internal node container BijectiveLayout
class BijectiveLayoutNode : public Node {
public:
/*! \brief Describes how source axes can be mapped to the destination axes,
* e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
*/
Array<Expr> forward_rule;
/*! \brief Describes how destination axes can be mapped to the source axes */
Array<Expr> backward_rule;
/*! \brief The source layout */
Layout src_layout;
/*! \brief The destination layout */
Layout dst_layout;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule);
v->Visit("backward_rule", &backward_rule);
}
static constexpr const char* _type_key = "BijectiveLayout";
TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node);
TVM_DLL static BijectiveLayout make(const Layout& src_layout,
const Layout& dst_layout);
};
/*! \brief Bijective function mapping for data layout transformation.
* Given two Layouts, BijectiveLayout builds and stores the mapping rules,
* and provides APIs to transform an N-dimensional tensor from the source indices (i0, i1, …, im)
* to the destination indices (j0, j1, …, jm).
*/
class BijectiveLayout : public NodeRef {
public:
BijectiveLayout() = default;
explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {}
// Given the source shape, infer the destination shape.
TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
// Given the destination shape, recover the source shape.
TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const;
// Given the source indices, infer the destination indices.
TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const;
// Given the destination indices, recover the source indices.
TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const BijectiveLayoutNode* operator->() const;
/*! \brief specify container node */
using ContainerType = BijectiveLayoutNode;
};
inline const BijectiveLayoutNode* BijectiveLayout::operator->() const {
return static_cast<const BijectiveLayoutNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_DATA_LAYOUT_H_
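A short example of querying a layout through the Python helpers documented later in this diff (illustrative sketch; index_of and factor_of are the bindings of IndexOf and FactorOf above):

import tvm

layout = tvm.layout("NCHW16c")
assert len(layout) == 5              # axes: N, C, H, W, 16c
assert layout.index_of("c") == 4     # position of the subordinate channel axis
assert layout.factor_of("C") == 16   # factor size of the split channel axis
assert layout.index_of("O") == -1    # axis not present in the layout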
......@@ -674,42 +674,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
const Array<Tensor>& inputs,
const Array<Tensor>& outputs) {
const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
Layout src_layout(param.src_layout);
Layout dst_layout(param.dst_layout);
if (src_layout == dst_layout) {
return Array<Tensor>{ inputs[0] };
} else if (!src_layout.defined() || !dst_layout.defined()) {
LOG(FATAL) << "cannot convert from/to undefined layout";
}
CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << param.src_layout
<< " to " << param.dst_layout;
return Array<Tensor> {
topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) {
std::vector<Expr> dst_to_src_indices;
for (Layout::LayoutDim src_axis : src_layout) {
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::is_superdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::is_subdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<Expr>(dst_to_src_indices);
})
return Array<Tensor>{
topi::layout_transform(inputs[0], param.src_layout, param.dst_layout)
};
})
.set_support_level(1);
......
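The nnvm compute above now delegates to the new topi.layout_transform operator. A hedged usage sketch, assuming the Python signature mirrors the C++ call site (data tensor, source layout string, destination layout string):

import tvm
import topi

data = tvm.placeholder((1, 32, 7, 7), name="data")
# Produce a tensor of shape (1, 2, 7, 7, 16) laid out as NCHW16c.
out = topi.layout_transform(data, "NCHW", "NCHW16c")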
......@@ -515,7 +515,7 @@ def decl_buffer(shape,
scope="",
data_alignment=-1,
offset_factor=0):
"""Decleare a new symbolic buffer.
"""Declare a new symbolic buffer.
Normally buffer is created automatically during lower and build.
This is only needed if users want to specify their own buffer layout.
......@@ -587,6 +587,49 @@ def decl_buffer(shape,
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)
def layout(layout_str):
"""Create a layout node from a string.
Parameters
----------
layout_str : str
A layout representation is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).
Returns
-------
layout : Layout
The created layout
"""
return _api_internal._Layout(layout_str)
def bijective_layout(src_layout, dst_layout):
"""Create a bijective layout mapping.
Parameters
----------
src_layout : str or Layout
source layout.
dst_layout : str or Layout
destination layout.
Returns
-------
bijective_layout : BijectiveLayout
The created bijective layout
"""
if isinstance(src_layout, str):
src_layout = layout(src_layout)
if isinstance(dst_layout, str):
dst_layout = layout(dst_layout)
return _api_internal._BijectiveLayout(src_layout, dst_layout)
def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar
......
......@@ -185,3 +185,142 @@ class HybridOp(Operation):
def axis(self):
"""Represent axis of IterVar, also defined when it is a HybridOp"""
return self.__getattr__("axis")
@register_node
class Layout(NodeBase):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
Do not construct directly, use :any:`layout` instead.
See the documentation of :any:`layout` for more details.
See Also
--------
layout : Declare a layout
"""
def __str__(self):
return self.name
def __repr__(self):
return "Layout(" + self.name + ")"
def __len__(self):
return _api_internal._LayoutNdim(self)
def __contains__(self, axis):
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Layout index out of range")
return _api_internal._LayoutGetItem(self, index)
def index_of(self, axis):
"""Get the index of an axis
Parameters
----------
axis : str
The axis name, needs to be [a-z,A-Z]
Returns
-------
index : int
The index of the axis, -1 if not found.
"""
return _api_internal._LayoutIndexOf(self, axis)
def factor_of(self, axis):
"""Get the factor size of the subordinate axis.
Parameters
----------
axis : str
The axis name, needs to be [a-z,A-Z]
Returns
-------
factor : int
the size of the subordinate-axis of axis (if axis is a primal-axis),
or the size of axis itself (if axis is a subordinate-axis).
Return -1 if axis is not in the layout.
"""
return _api_internal._LayoutFactorOf(self, axis)
@register_node
class BijectiveLayout(NodeBase):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
Do not construct directly, use :any:`bijective_layout` instead.
See the documentation of :any:`bijective_layout` for more details.
See Also
--------
bijective_layout : Declare a bijective layout converter
"""
def forward_index(self, index):
"""Given the indices of the src-layout, infer the dst index.
Parameters
----------
index: Array of Expr
The indices in src-layout.
Returns
-------
dst_index: Array of Expr
The inferred indices in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardIndex(self, index)
def backward_index(self, index):
"""Given the indices of the dst-layout, infer the src index.
Parameters
----------
index: Array of Expr
The indices in dst-layout.
Returns
-------
src_index: Array of Expr
The inferred indices in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardIndex(self, index)
def forward_shape(self, shape):
"""Given the shape of the src-layout, infer the dst shape.
Parameters
----------
shape: Array of Expr
The shape in src-layout.
Returns
-------
dst_shape: Array of Expr
The inferred shape in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardShape(self, shape)
def backward_shape(self, shape):
"""Given the shape of the dst-layout, infer the src shape.
Parameters
----------
shape: Array of Expr
The shape in dst-layout.
Returns
-------
src_shape: Array of Expr
The inferred shape in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardShape(self, shape)
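A hedged sketch of how these methods behave for an NCHW to NCHW16c mapping; the commented values follow from the forward/backward rules built in data_layout.cc below (illustrative example, not quoted verbatim from the test suite):

import tvm

bijective = tvm.bijective_layout("NCHW", "NCHW16c")
# Index conversion: channel 18 maps to block 1, sub-channel 2.
dst_index = bijective.forward_index([0, 18, 6, 6])      # -> [0, 1, 6, 6, 2]
src_index = bijective.backward_index([0, 1, 6, 6, 2])   # -> [0, 18, 6, 6]
# Shape conversion: the channel dimension 32 is split into 2 x 16.
dst_shape = bijective.forward_shape((1, 32, 7, 7))      # -> [1, 2, 7, 7, 16]
src_shape = bijective.backward_shape(dst_shape)         # -> [1, 32, 7, 7]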
......@@ -11,6 +11,7 @@
#include <tvm/schedule.h>
#include <tvm/api_registry.h>
#include <tvm/build_module.h>
#include <tvm/data_layout.h>
namespace tvm {
......@@ -224,6 +225,63 @@ TVM_REGISTER_API("_BufferVStore")
.vstore(args[1], args[2]);
});
TVM_REGISTER_API("_Layout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LayoutNode::make(args[0]);
});
TVM_REGISTER_API("_LayoutIndexOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.IndexOf(LayoutAxis::make(args[1]));
});
TVM_REGISTER_API("_LayoutFactorOf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Layout()
.FactorOf(LayoutAxis::make(args[1]));
});
TVM_REGISTER_API("_LayoutNdim")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(args[0].operator Layout().ndim());
});
TVM_REGISTER_API("_LayoutGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
const LayoutAxis& axis = args[0].operator Layout()[args[1]];
*ret = axis.name();
});
TVM_REGISTER_API("_BijectiveLayout")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BijectiveLayoutNode::make(args[0], args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.ForwardShape(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator BijectiveLayout()
.BackwardShape(args[1]);
});
TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
......
/*!
* Copyright (c) 2019 by Contributors
* \file src/lang/data_layout.cc
* \brief Data Layout expression.
*/
#include <tvm/data_layout.h>
#include <tvm/ir_pass.h>
namespace tvm {
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode);
const LayoutAxis LayoutAxis::UPPER_CASE[] = {
LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
LayoutAxis('Z')
};
const LayoutAxis LayoutAxis::LOWER_CASE[] = {
LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
LayoutAxis('z')
};
const LayoutAxis& LayoutAxis::Get(const char name) {
CHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z'))
<< "Invalid layout axis name: " << name << ". Has to be A-Z or a-z.";
return (name >= 'A' && name <= 'Z') ?
LayoutAxis::UPPER_CASE[name-'A'] :
LayoutAxis::LOWER_CASE[name-'a'];
}
const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
const std::string axis = itvar->var.get()->name_hint;
CHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis;
return LayoutAxis::Get(axis[0]);
}
const LayoutAxis& LayoutAxis::make(const std::string& name) {
CHECK_EQ(name.length(), 1) << "Invalid axis " << name;
return LayoutAxis::Get(name[0]);
}
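// Construct a Layout from its axes and regenerate the canonical name, e.g.,
// axes [N, C, H, W, c(extent=16)] produce the name "NCHW16c"
// (editorial note; the factor of a subordinate axis is read from its IterVar extent).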
Layout::Layout(const Array<IterVar>& axes) {
node_ = make_node<LayoutNode>();
LayoutNode *node = operator->();
node->axes = axes;
std::ostringstream repr;
for (const IterVar& axis : axes) {
if (const auto* factor = axis->dom->extent.as<IntImm>()) {
CHECK_GT(factor->value, 0);
repr << factor->value;
}
CHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis "
<< axis->var.get()->name_hint;
char c = axis->var.get()->name_hint[0];
CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c;
repr << axis->var.get()->name_hint;
}
node->name = repr.str();
}
Layout::Layout(const std::string& name) { // NOLINT(*)
if (name.empty() || name == "__undef__") return;
node_ = make_node<LayoutNode>();
LayoutNode *node = operator->();
node->name = name;
// parse layout string
int32_t factor = 0;
for (char c : name) {
if (c >= 'A' && c <= 'Z') {
CHECK_EQ(factor, 0) << "Invalid layout " << name
<< ": invalid factor size " << factor
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
IterVar axis = IterVarNode::make(Range(Expr(0), Var(shape_name)),
Var(std::string(1, c)), kDataPar);
node->axes.push_back(axis);
} else if (c >= 'a' && c <= 'z') {
CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
<< factor << " for dimension " << c;
IterVar axis = IterVarNode::make(Range(Expr(0), Expr(factor)),
Var(std::string(1, c)), kDataPar);
node->axes.push_back(axis);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << name;
}
}
// validate layout
std::vector<bool> exist_axis(256, false);
for (const IterVar& v : node->axes) {
auto axis_str = v->var.get()->name_hint;
CHECK_EQ(axis_str.size(), 1);
char axis = axis_str[0];
CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
CHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis;
exist_axis[axis] = true;
}
for (const IterVar& v : node->axes) {
char axis = v->var.get()->name_hint[0];
if (axis >= 'a' && axis <= 'z') {
CHECK(exist_axis[axis-'a'+'A']) << "Invalid layout " << name << ": missing axis "
<< static_cast<char>(axis - 'a' + 'A');
}
}
}
Layout LayoutNode::make(const std::string& layout) {
return Layout(layout);
}
Layout Layout::SubLayout(size_t pos, size_t len) const {
if (!defined() || pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos;
Array<IterVar> new_layout;
const auto axes = operator->()->axes;
for (size_t i = pos; i < pos + len; ++i) {
new_layout.push_back(axes[i]);
}
return Layout(new_layout);
}
Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const {
if (!defined()) return Layout::Undef();
const std::string& name = operator->()->name;
const auto axes = operator->()->axes;
CHECK(target_pos <= this->ndim()) << "Invalid split position "
<< target_pos << " for layout " << name;
CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis;
CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name;
CHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis
<< " has already been split in " << name;
CHECK(factor > 0) << "Invalid split size " << factor;
Array<IterVar> new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
new_layout.push_back(IterVarNode::make(Range(Expr(0), Expr(factor)),
Var(axis.ToSubordinate().name()), kDataPar));
}
if (i == this->ndim()) break;
new_layout.push_back(axes[i]);
}
return Layout(new_layout);
}
int32_t Layout::FactorOf(const LayoutAxis& axis) const {
if (!defined()) return -1;
const LayoutAxis& sub = axis.ToSubordinate();
if (!this->defined()) return -1;
for (const IterVar& itvar : operator->()->axes) {
if (sub == LayoutAxis::Get(itvar)) {
const auto* factor = itvar->dom->extent.as<IntImm>();
CHECK(factor);
return factor->value;
}
}
return -1;
}
inline bool GetStoreRule(Array<Expr>* rule,
const Layout& src_layout,
const Layout& dst_layout) {
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i];
Expr store(0);
for (size_t j = 0; j < src_layout.ndim(); ++j) {
const auto& orig_axis = src_layout[j];
const IterVar& orig_axis_impl = src_layout->axes[j];
if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
if (orig_axis.IsPrimal()) {
Expr orig_var = orig_axis_impl->var;
const int32_t factor = src_layout.FactorOf(orig_axis);
if (factor > 0) {
orig_var = orig_var * Expr(factor);
}
store = store + orig_var;
} else {
store = store + orig_axis_impl->var;
}
}
}
if (is_zero(store)) {
// Not convertible
return false;
}
if (store_axis.IsPrimal()) {
const int32_t factor = dst_layout.FactorOf(store_axis);
if (factor > 0) {
store = store / Expr(factor);
}
} else {
store = store % store_axis_impl->dom->extent;
}
rule->push_back(store);
}
return true;
}
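// Worked example for GetStoreRule (editorial note), matching the NC -> NC16n
// case in the header comment: for each destination axis we sum the contributing
// source axes, then divide a primal axis by its factor or take the modulo for a
// subordinate axis, giving
//   forward_rule  = [N / 16, C, N % 16]   (source indices -> destination indices)
//   backward_rule = [N * 16 + n, C]       (destination indices -> source indices)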
inline Array<Expr> TransformIndex(const Array<Expr>& src_index,
const Array<IterVar>& src_axis,
const Array<Expr>& transform_rule) {
Array<Expr> result;
std::unordered_map<const Variable*, Expr> bind_map;
for (size_t i = 0; i < src_index.size(); ++i) {
bind_map[src_axis[i]->var.get()] = src_index[i];
}
for (Expr rule : transform_rule) {
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
}
return result;
}
Array<Expr> BijectiveLayout::ForwardIndex(const Array<Expr>& src_index) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(src_index.size(), self->src_layout->axes.size())
<< "Input mismatch with layout " << self->src_layout;
return TransformIndex(src_index, self->src_layout->axes, self->forward_rule);
}
Array<Expr> BijectiveLayout::BackwardIndex(const Array<Expr>& dst_index) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
<< "Output mismatch with layout " << self->dst_layout;
return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule);
}
inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
const Array<IterVar>& src_axis,
const Array<IterVar>& target_axis,
const Array<Expr>& transform_rule) {
CHECK_EQ(src_shape.size(), src_axis.size());
// bind variables for original axes
// for major-axis, bind the corresponding size
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
// e.g., (C * 16 + c) / 32
std::unordered_map<const Variable*, Expr> bind_map;
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
if (orig_shape.defined()) {
const auto* orig_shape_const = orig_shape.as<IntImm>();
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
<< "Input shape mismatch at index " << i << ". Expected "
<< orig_axis->dom->extent << ", get " << orig_shape;
}
bind_map[orig_axis->var.get()] = Expr(0);
} else {
bind_map[orig_axis->var.get()] = orig_shape;
}
}
// infer the target shape,
// for major-axis, use the forward/backward_rule directly,
// for minor-axis, simply use the extent.
Array<Expr> result;
CHECK_EQ(transform_rule.size(), target_axis.size());
for (size_t i = 0; i < transform_rule.size(); ++i) {
Expr rule = transform_rule[i];
IterVar axis = target_axis[i];
if (!LayoutAxis::Get(axis).IsPrimal()) {
result.push_back(axis->dom->extent);
} else {
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
}
}
return result;
}
Array<Expr> BijectiveLayout::ForwardShape(const Array<Expr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->src_layout->axes,
self->dst_layout->axes, self->forward_rule);
}
Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->dst_layout->axes,
self->src_layout->axes, self->backward_rule);
}
BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
const Layout& dst_layout) {
auto n = make_node<BijectiveLayoutNode>();
n->src_layout = src_layout;
n->dst_layout = dst_layout;
if (!GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
// not convertible
return BijectiveLayout();
}
CHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
return BijectiveLayout(n);
}
} // namespace tvm
......@@ -4,13 +4,13 @@
* \brief Property def of nn operators.
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/debug.h>
#include <topi/elemwise.h>
#include <vector>
#include "./type_relations.h"
#include "./op_common.h"
#include "./layout.h"
namespace tvm {
namespace relay {
......
......@@ -3,11 +3,11 @@
* \file resize.cc
* \brief Image operators
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h>
#include <topi/elemwise.h>
#include <topi/image/resize.h>
#include "../layout.h"
#include "../op_common.h"
namespace tvm {
......@@ -28,17 +28,18 @@ bool ResizeRel(const Array<Type>& types,
const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW))
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
oshape[2] = param->size[0];
oshape[3] = param->size[1];
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, param->size[0]);
oshape.Set(3, param->size[1]);
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/relay/op/layout.cc
* \brief Layout expression.
*/
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(LayoutNode);
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout) {
CHECK_EQ(src_layout.ndim(), src.size());
if (src_layout == dst_layout) {
return src;
} else if (!src_layout.defined()) {
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
} else if (!dst_layout.defined()) {
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
}
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from "
<< src_layout << " to " << dst_layout;
std::vector<IndexExpr> dst(dst_layout.ndim());
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_dim = src_layout[i];
if (Layout::IsSuperdim(src_dim)) {
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
int src_factor = src_layout.Subsizeof(src_dim);
int dst_factor = dst_layout.Subsizeof(src_dim);
IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) {
CHECK(is_const_int(src[src_minor_pos], src_factor))
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
src_dim_size *= src_factor;
}
dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
CHECK_LE(dst_factor, const_src_dim_size[0])
<< "Converting " << Array<IndexExpr>(src)
<< " from " << src_layout
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
}
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
}
}
}
return dst;
}
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout) {
std::vector<IndexExpr> ret(src.size());
for (size_t i = 0; i < src.size(); ++i) {
ret[i] = src[i];
}
return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file relay/op/layout.h
* \brief Layout expression.
*
* This file is adapted from its nnvm counterpart and will keep evolving
* toward the new layout system.
*
* The layout is composed of upper cases, lower cases and numbers,
* where upper case indicates a (super-)dimension and
* the corresponding lower case with factor size indicates the split (sub-)dimension.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
*/
#ifndef TVM_RELAY_OP_LAYOUT_H_
#define TVM_RELAY_OP_LAYOUT_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/relay/base.h>
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
namespace tvm {
namespace relay {
class LayoutNode : public Node {
public:
std::string name;
Array<Integer> superdim_pos;
Array<Integer> subdim_pos;
Array<Integer> subdim_size;
Array<Integer> layout_simplified;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("superdim_pos", &superdim_pos);
v->Visit("subdim_pos", &subdim_pos);
v->Visit("subdim_size", &subdim_size);
v->Visit("layout_simplified", &layout_simplified);
}
static constexpr const char* _type_key = "Layout";
TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
};
class Layout : public NodeRef {
public:
using LayoutDim = char;
static constexpr uint32_t kUniqueDim = 26;
explicit Layout(NodePtr<Node> n) : NodeRef(n) {}
/*! \brief default constructor */
Layout() : Layout("__undef__") {} // NOLINT(*)
/*! \brief construct from a string */
Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
/*!
* \brief construct from a string.
* \param layout input in layout convention:
* upper case indicates a dimension and
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
Layout(const std::string& name) { // NOLINT(*)
node_ = make_node<LayoutNode>();
std::vector<int32_t> superdim_pos(kUniqueDim, -1);
std::vector<int32_t> subdim_pos(kUniqueDim, -1);
std::vector<int32_t> subdim_size(kUniqueDim, -1);
std::vector<char> layout_simplified;
if (name != "__undef__") { // parse layout string
int32_t factor = 0;
uint32_t curr = 0;
for (size_t i = 0; i < name.size(); ++i) {
const LayoutDim c = name.at(i);
if (IsSuperdim(c)) {
int pos = c - 'A';
CHECK_EQ(factor, 0) << "Invalid layout " << name
<< ": invalid factor size " << factor
<< " before dimension " << c;
CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << name
<< ": duplicate dimension " << c;
superdim_pos[pos] = curr++;
layout_simplified.push_back(c);
} else if (IsSubdim(c)) {
int pos = c - 'a';
CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
<< factor << " for dimension " << c;
CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << name
<< ": duplicate dimension " << c;
CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << name
<< ": duplicate dimension " << c;
subdim_pos[pos] = curr++;
subdim_size[pos] = factor;
layout_simplified.push_back(c);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << name;
}
}
for (LayoutDim dim : layout_simplified) {
CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
<< "Invalid layout " << name << ": missing axis "
<< static_cast<char>(dim - 'a' + 'A');
}
}
LayoutNode *node = operator->();
node->name = name;
for (uint32_t i = 0; i < kUniqueDim; ++i) {
node->superdim_pos.push_back(superdim_pos[i]);
node->subdim_pos.push_back(subdim_pos[i]);
node->subdim_size.push_back(subdim_size[i]);
}
for (LayoutDim dim : layout_simplified) {
node->layout_simplified.push_back(dim);
}
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
const LayoutNode* operator->() const {
return static_cast<const LayoutNode*>(node_.get());
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
LayoutNode* operator->() {
return static_cast<LayoutNode*>(node_.get());
}
/*!
* \brief Check whether a given dimension is a super-dimension.
* \param dim input dimension
* \return Whether a given dimension is a super-dimension.
*/
static bool IsSuperdim(LayoutDim dim) {
return dim >= 'A' && dim <= 'Z';
}
/*!
* \brief Check whether a given dimension is a sub-dimension.
* \param dim input dimension
* \return Whether a given dimension is a sub-dimension.
*/
static bool IsSubdim(LayoutDim dim) {
return dim >= 'a' && dim <= 'z';
}
/*!
* \brief Convert a given dimension to super-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim ToSuperdim(LayoutDim dim) {
if (IsSubdim(dim)) {
return dim - 'a' + 'A';
}
return dim;
}
/*!
* \brief Convert a given dimension to sub-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim ToSubdim(LayoutDim dim) {
if (IsSuperdim(dim)) {
return dim - 'A' + 'a';
}
return dim;
}
/*!
* \brief Return an undefined layout.
* \return a (global) undefined layout.
*/
static const Layout& Undef() {
static Layout undef;
return undef;
}
/*!
* \brief Two layouts are convertible only if
* they have same set of super-dimensions.
* e.g., NCHW, NCHW16c, NHWC are convertible between each other,
* but NCHW, CHW, OIHW are not.
* \param dst the target layout
* \return Whether can be converted to dst layout.
*/
bool Convertible(const Layout &dst) const {
const LayoutNode *n = operator->();
if (!this->defined() || !dst.defined()) return false;
for (size_t i = 0; i < kUniqueDim; ++i) {
if ((n->superdim_pos[i]->value >= 0 && dst->superdim_pos[i]->value < 0) ||
(n->superdim_pos[i]->value < 0 && dst->superdim_pos[i]->value >= 0)) {
return false;
}
}
return true;
}
/*!
* \brief Returns a sublayout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
* \return A newly constructed Layout object.
*/
Layout Sublayout(size_t pos, size_t len) const {
const Array<Integer>& layout_simplified = operator->()->layout_simplified;
if (pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos;
std::ostringstream new_layout;
for (size_t i = pos; i < pos + len; ++i) {
if (IsSubdim(layout_simplified[i]->value)) {
auto block_size = this->Subsizeof(layout_simplified[i]->value);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << static_cast<char>(layout_simplified[i]->value);
}
return Layout(new_layout.str());
}
/*! \return A newly constructed reversed Layout object. */
Layout Reverse() const {
const Array<Integer>& layout_simplified = operator->()->layout_simplified;
if (!this->defined()) return Layout::Undef();
std::ostringstream new_layout;
for (int64_t i = this->ndim() - 1; i >= 0; --i) {
if (IsSubdim(layout_simplified[i]->value)) {
auto block_size = this->Subsizeof(layout_simplified[i]->value);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified[i]->value;
}
return Layout(new_layout.str());
}
/*!
* \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
* \param dim The source dimension to be split. It must be a super-dimension.
* \param target_pos The target position of the newly split sub-dimension.
* \param size size of the sub-dimension.
* \return A newly constructed Layout object.
*/
Layout Split(LayoutDim dim, size_t target_pos, uint32_t size) const {
const std::string &name = operator->()->name;
CHECK(target_pos <= this->ndim()) << "Invalid split position "
<< target_pos << " for layout " << name;
CHECK(IsSuperdim(dim)) << "Cannot split a sub-dimension " << dim;
CHECK(this->Contains(dim)) << "Axis " << dim << " does not exist in " << name;
CHECK(!this->Contains(ToSubdim(dim))) << "Dimension " << dim
<< " has already been split in "
<< name;
CHECK(size > 0) << "Invalid split size " << size;
std::ostringstream new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
new_layout << size << Layout::ToSubdim(dim);
}
if (i == this->ndim()) break;
new_layout << this->at(i);
}
Layout x(new_layout.str());
return x;
}
/*! \return number of dimensions */
size_t ndim() const {
return operator->()->layout_simplified.size();
}
/*! \return number of super dimensions */
size_t ndim_super() const {
size_t ct = 0;
for (auto x : operator->()->layout_simplified) {
if (IsSuperdim(x))
ct++;
}
return ct;
}
/*!
* \brief The description of the \p i-th dimension.
* If it is a sub-dimension, the size will be returned as well,
* e.g., 16c. Otherwise a single character is returned, e.g., C.
* \param i The position
* \return the description of the dimension.
*/
std::string at(size_t i) const {
const Array<Integer>& layout_simplified = operator->()->layout_simplified;
CHECK_LT(i, this->ndim()) << "position " << i
<< " exceeds ndim=" << this->ndim();
std::ostringstream repr;
if (IsSubdim(layout_simplified[i]->value)) {
auto factor = Subsizeof(layout_simplified[i]->value);
CHECK_GT(factor, 0);
repr << factor;
}
repr << static_cast<char>(layout_simplified[i]->value);
return repr.str();
}
/*!
* \brief return the index of the input dimension.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param dim the input dimension.
* \return the index or -1 if not found.
*/
int32_t Indexof(LayoutDim dim) const {
if (!this->defined()) return -1;
else if (IsSuperdim(dim)) return operator->()->superdim_pos[dim - 'A']->value;
else if (IsSubdim(dim)) return operator->()->subdim_pos[dim - 'a']->value;
return -1;
}
/*!
* \param dim the input super-dimension or sub-dimension.
* \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
* or the size of \p dim itself (if \p dim is a sub-dimension).
* Return -1 if \p dim is not in the layout or the layout is undefined.
*/
int64_t Subsizeof(LayoutDim dim) const {
CHECK(IsSuperdim(dim) || IsSubdim(dim)) << "Invalid dim " << dim;
if (!this->defined() || !this->Contains(ToSubdim(dim))) {
return -1;
}
int idx = ToSubdim(dim) - 'a';
return operator->()->subdim_size[idx]->value;
}
/*!
* \brief Whether the layout contains a dimension.
* \param dim dimension to be checked.
* \return Whether the layout contains the dimension.
*/
bool Contains(LayoutDim dim) const {
if (IsSuperdim(dim)) {
return operator->()->superdim_pos[dim-'A']->value >= 0;
} else if (IsSubdim(dim)) {
return operator->()->subdim_pos[dim-'a']->value >= 0;
}
return false;
}
LayoutDim operator[](size_t i) const {
return operator->()->layout_simplified[i];
}
/*! \return whether the layout is defined */
bool defined() const {
return operator->()->name != "__undef__";
}
/*! \return the string description of the layout */
const std::string& name() const {
return operator->()->name;
}
/*!
* \brief Whether the two layouts are equal.
* \param rhs Another layout.
* \return whether the two layouts are equal.
*/
bool Equals(const Layout &rhs) const {
return operator->()->name == rhs->name;
}
/*!
* \brief allow output string of layout to ostream
* \param os the output stream
* \param l the layout
* \return the ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
os << l.name();
return os;
}
using ContainerType = LayoutNode;
};
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout);
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_LAYOUT_H_
......@@ -3,12 +3,12 @@
* \file convolution.cc
* \brief Convolution operators
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
#include "../../pass/alter_op_layout.h"
#include "../layout.h"
namespace tvm {
namespace relay {
......@@ -31,32 +31,36 @@ bool Conv2DRel(const Array<Type>& types,
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.Convertible(kOIHW))
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW))
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
std::vector<IndexExpr> dshape_nchw = ConvertLayout(
data->shape, in_layout, kNCHW);
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape(
Array<IndexExpr> wshape(
{param->channels,
dshape_nchw[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
......@@ -65,7 +69,7 @@ bool Conv2DRel(const Array<Type>& types,
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
......@@ -73,13 +77,13 @@ bool Conv2DRel(const Array<Type>& types,
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
<< " wshape=" << wshape;
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
<< " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
channels = wshape[0];
......@@ -87,15 +91,15 @@ bool Conv2DRel(const Array<Type>& types,
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = ConvertLayout(oshape, kNCHW, out_layout);
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
......@@ -193,33 +197,38 @@ bool Conv2DTransposeRel(const Array<Type>& types,
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.Convertible(kOIHW))
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW))
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape({dshape_nchw[1],
Array<IndexExpr> wshape({dshape_nchw[1],
param->channels / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
channels = param->channels;
......@@ -229,7 +238,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
......@@ -251,17 +260,17 @@ bool Conv2DTransposeRel(const Array<Type>& types,
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
2 * param->padding[0] + param->output_padding[0]);
oshape[3] = (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1]);
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
2 * param->padding[0] + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = ConvertLayout(oshape, kNCHW, out_layout);
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
......@@ -349,20 +358,24 @@ bool Conv2DWinogradRel(const Array<Type>& types,
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.Convertible(kOIHW))
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW))
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
std::vector<IndexExpr> dshape_nchw = ConvertLayout(
data->shape, in_layout, kNCHW);
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
......@@ -384,15 +397,15 @@ bool Conv2DWinogradRel(const Array<Type>& types,
// can handle this correctly in alter_op_layout.
// dilation
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = ConvertLayout(oshape, kNCHW, out_layout);
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
......
......@@ -4,6 +4,7 @@
* \brief Property def of nn operators.
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h>
......@@ -14,7 +15,6 @@
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
#include "../op_common.h"
#include "../layout.h"
namespace tvm {
namespace relay {
......
......@@ -3,12 +3,12 @@
* \file pad.cc
* \brief Implementation of operator pad
*/
#include <tvm/data_layout.h>
#include <tvm/ir_operator.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <topi/nn.h>
#include <vector>
#include "../layout.h"
#include "../op_common.h"
namespace tvm {
......
......@@ -3,12 +3,12 @@
* \file pooling.cc
* \brief Pooling operators
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h>
#include <vector>
#include "../layout.h"
#include "../../pass/alter_op_layout.h"
namespace tvm {
......@@ -32,14 +32,15 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.Indexof('W') == raw_layout.Indexof('W') &&
input.Indexof('H') == raw_layout.Indexof('H') &&
!input.Contains('w') && !input.Contains('h')) {
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
}
}
return Array<Array<Layout> >{{params->layout}, {params->layout}};
Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}
template <typename AttrType>
......@@ -59,13 +60,13 @@ bool Pool2DRel(const Array<Type>& types,
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') &&
!layout.Contains('h') && !layout.Contains('w'))
CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H');
const auto widx = layout.Indexof('W');
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
IndexExpr pad_h, pad_w;
if (param->padding.size() == 1) {
......@@ -125,6 +126,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
auto pool_size = param->pool_size;
......@@ -132,10 +134,13 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
auto padding = param->padding;
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW")))
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1) << "max_pool2d does not support input split on height";
CHECK_EQ(layout.Indexof('w'), -1) << "max_pool2d does not support input split on width";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "max_pool2d does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
......@@ -271,13 +276,13 @@ bool GlobalPool2DRel(const Array<Type>& types,
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Contains('H') && layout.Contains('W') &&
!layout.Contains('h') && !layout.Contains('w'))
CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
!layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
<< "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.Indexof('H');
const auto widx = layout.Indexof('W');
const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
Array<IndexExpr> oshape(dshape);
oshape.Set(hidx, 1);
oshape.Set(widx, 1);
......@@ -293,14 +298,15 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.Convertible(Layout("NCHW")))
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.Indexof('h'), -1)
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "global_avg_pool2d does not support input split on height";
CHECK_EQ(layout.Indexof('w'), -1)
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "global_avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
......
......@@ -3,6 +3,7 @@
* \file upsampling.cc
* \brief upsampling operator
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h>
......@@ -11,7 +12,6 @@
#include <topi/nn/upsampling.h>
#include <vector>
#include "../op_common.h"
#include "../layout.h"
namespace tvm {
namespace relay {
......@@ -31,18 +31,20 @@ bool UpSamplingRel(const Array<Type>& types,
const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
CHECK(in_layout.Convertible(kNCHW))
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "UpSampling only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
auto oshape = layout_converter.ForwardShape(data->shape);
oshape[2] = oshape[2] * param->scale;
oshape[3] = oshape[3] * param->scale;
oshape.Set(2, oshape[2] * param->scale);
oshape.Set(3, oshape[3] * param->scale);
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
......
......@@ -7,6 +7,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <tvm/data_layout.h>
#include <topi/transform.h>
#include <topi/elemwise.h>
#include <topi/broadcast.h>
......@@ -16,7 +17,6 @@
#include "../op_common.h"
#include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h"
#include "../layout.h"
namespace tvm {
namespace relay {
......@@ -218,7 +218,7 @@ Array<Array<Layout>> ConcatenateLayout(
Layout ret;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
Layout::LayoutDim concate_dim = old_in_layouts[0][axis];
const auto& concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) {
if (new_in_layouts[i].ndim() > axis &&
new_in_layouts[i][axis] == concate_dim) {
......@@ -234,7 +234,7 @@ Array<Array<Layout>> ConcatenateLayout(
}
}
if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) {
if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
}
......@@ -1682,46 +1682,10 @@ Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const LayoutTransformAttrs *param = attrs.as<LayoutTransformAttrs>();
const auto* param = attrs.as<LayoutTransformAttrs>();
CHECK(param != nullptr);
Layout src_layout(param->src_layout);
Layout dst_layout(param->dst_layout);
if (src_layout.Equals(dst_layout)) {
return Array<Tensor>{ inputs[0] };
}
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from " << param->src_layout << " to " << param->dst_layout;
const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout);
return Array<Tensor> {
topi::layout_transform(inputs[0], out_shape, [&](const Array<tvm::Var>& dst_indices) {
std::vector<tvm::Expr> dst_to_src_indices;
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_axis = src_layout[i];
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.Subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.Subsizeof(src_axis));
tvm::Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::IsSuperdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::IsSubdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<tvm::Expr>(dst_to_src_indices);
})
return Array<Tensor>{
topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
};
}
......@@ -1738,10 +1702,12 @@ bool LayoutTransformRel(const Array<Type>& types,
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout);
CHECK(layout_converter.defined())
<< "cannot convert from " << params->src_layout << " to " << params->dst_layout;
const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout);
const auto& out_shape = layout_converter.ForwardShape(data->shape);
reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
return true;
}
......
......@@ -26,7 +26,7 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
if (src_layout.Equals(dst_layout)) { return raw; }
CHECK(src_layout.defined() && dst_layout.defined())
<< "Cannot insert layout transform because there are undefined layouts";
CHECK(src_layout.Convertible(dst_layout))
CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
<< "Cannot insert layout transform because there are inconvertible layouts: "
<< src_layout << " v.s. " << dst_layout;
static auto &transform_op = Op::Get("layout_transform");
......
......@@ -9,10 +9,9 @@
#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#include <tvm/data_layout.h>
#include <tvm/relay/expr.h>
#include "../op/layout.h"
namespace tvm {
namespace relay {
......@@ -78,7 +77,7 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx,
layouts[defined_idx].Sublayout(
layouts[defined_idx].SubLayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
......@@ -90,21 +89,22 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
}
} else {
// try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1;
int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
int small_idx = 1 - large_idx;
Layout ret = layouts[large_idx];
// extract common part
size_t i = layouts[large_idx].ndim();
for (; i != 0; --i) {
auto dim = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) {
const auto& axis = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(axis.ToPrimal())) {
break;
}
}
Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i);
if (!layouts[small_idx].Convertible(common_part)) { // fail
Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
// not convertible
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
......
......@@ -91,8 +91,10 @@ class BranchGroupFinder : private ExprVisitor {
CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW);
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW);
const auto shape_a = BijectiveLayoutNode::make(
Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
const auto shape_b = BijectiveLayoutNode::make(
Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
......
......@@ -6,12 +6,12 @@
* \brief Fold axis scaling into weights of
* conv/dense operators.
*/
#include <tvm/data_layout.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "../op/layout.h"
namespace tvm {
......@@ -435,8 +435,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C');
int c_small_axis = data_layout.Indexof('c');
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>();
......@@ -449,7 +449,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('i') < 0 &&
if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
......@@ -473,15 +473,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C');
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O');
int big_ic_axis = kernel_layout.Indexof('I');
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
......@@ -857,8 +857,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C');
int c_small_axis = out_layout.Indexof('c');
int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
......@@ -869,8 +869,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.Indexof('o') < 0 &&
kernel_layout.Indexof('i') < 0 &&
if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return MessageNode::make({c_big_axis}, false);
......@@ -891,16 +891,16 @@ Expr Conv2DBackwardTransform(const Call& call,
CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C');
int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('o'), -1);
CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(message->axes.size() == 1 &&
c_big_axis == message->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O');
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
......
......@@ -11,7 +11,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "../op/layout.h"
#include <tvm/data_layout.h>
namespace tvm {
namespace relay {
......@@ -51,8 +51,8 @@ int64_t ConvMacCount(const Call& call_node) {
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_attr->data_layout;
int32_t C_ind = Layout(data_layout).Indexof('C');
int32_t c_ind = Layout(data_layout).Indexof('c');
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK(C_ind != -1)
<< "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
......
......@@ -8,13 +8,13 @@
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/nn.h>
#include <string>
#include "../op/layout.h"
namespace tvm {
......@@ -155,9 +155,8 @@ inline bool IsDepthwiseConv2D(const Call& call,
const Conv2DAttrs* param,
const Layout& kernel_layout) {
static const Layout kOIHW("OIHW");
auto wshape = ConvertLayout(
call->args[1]->type_as<TensorTypeNode>()->shape,
kernel_layout, kOIHW);
const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
return is_const_int(wshape[0], param->groups) &&
is_const_int(wshape[1], 1);
}
......
"""Test layout and bijective-layout node"""
import tvm
from topi.util import get_const_tuple
def test_layout():
layout = tvm.layout("NCHW16c")
assert layout is not None
assert isinstance(layout, tvm.tensor.Layout)
assert layout.factor_of("c") == 16
assert layout.factor_of("C") == 16
assert layout.factor_of("N") == -1
assert layout.index_of("N") == 0
assert layout.index_of("C") == 1
assert layout.index_of("H") == 2
assert layout.index_of("W") == 3
assert layout.index_of("c") == 4
assert layout.index_of("O") == -1
assert "N" in layout
assert "C" in layout
assert "H" in layout
assert "W" in layout
assert "c" in layout
assert "O" not in layout
assert layout[0] == "N"
assert layout[1] == "C"
assert layout[2] == "H"
assert layout[3] == "W"
assert layout[4] == "c"
assert layout[-1] == "c"
def test_bilayout_convertible():
# not convertible
assert tvm.bijective_layout("NCHW", "ABCD") is None
# convertible
assert tvm.bijective_layout("NCHW", "NCHW16c") is not None
def test_bilayout_shape():
bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
assert isinstance(bilayout, tvm.tensor.BijectiveLayout)
dst_shape = bilayout.forward_shape((1, 32, 7, 7))
assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
src_shape = bilayout.backward_shape(dst_shape)
assert get_const_tuple(src_shape) == (1, 32, 7, 7)
def test_bilayout_index():
bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
dst_index = bilayout.forward_index([0, 18, 6, 6])
assert get_const_tuple(dst_index) == (0, 1, 6, 6, 2)
src_index = bilayout.backward_index([0, 1, 6, 6, 2])
assert get_const_tuple(src_index) == (0, 18, 6, 6)
if __name__ == "__main__":
test_layout()
test_bilayout_convertible()
test_bilayout_shape()
test_bilayout_index()
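For reference, the index mapping exercised by test_bilayout_index can be reproduced by hand. A minimal sketch (not part of the test suite; the helper name nchw_to_nchw16c is illustrative), assuming the C axis is split by a factor of 16:
def nchw_to_nchw16c(n, c, h, w, factor=16):
    # forward_index splits C into an outer axis (c // factor) and an inner axis (c % factor)
    return (n, c // factor, h, w, c % factor)
assert nchw_to_nchw16c(0, 18, 6, 6) == (0, 1, 6, 6, 2)  # matches forward_index([0, 18, 6, 6])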
......@@ -450,28 +450,5 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
return tvm::compute(output_shape, l, name, tag);
}
using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;
/*!
* \brief Transform the layout according to the mapping function \p to_src_indices.
* \param src the source input.
* \param dst_shape the output shape.
* \param to_src_indices the mapping function from an output (dst) index to the corresponding input (src) index.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape \p dst_shape.
*/
inline Tensor layout_transform(const Tensor& src,
const Array<Expr>& dst_shape,
const FLayoutIndicesTransform& to_src_indices,
const std::string name = "layout_transform",
const std::string tag = kInjective) {
auto src_shape = src->shape;
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
return src(to_src_indices(dst_indices));
}, name, tag);
}
} // namespace topi
#endif // TOPI_NN_H_
......@@ -16,6 +16,7 @@
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "tvm/tvm.h"
#include "tvm/data_layout.h"
namespace topi {
using namespace tvm;
......@@ -882,5 +883,43 @@ inline Tensor arange(const Expr start,
}, name, tag);
}
/*!
* \brief Transform the layout according to \p src_layout and \p dst_layout
* \param src the source input.
* \param src_layout the source layout.
* \param dst_layout the destination layout.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape in \p dst_layout
*/
inline Tensor layout_transform(const Tensor& src,
const std::string& src_layout,
const std::string& dst_layout,
const std::string name = "layout_transform",
const std::string tag = kInjective) {
Layout src_layout_struct = LayoutNode::make(src_layout);
Layout dst_layout_struct = LayoutNode::make(dst_layout);
if (src_layout_struct.Equals(dst_layout_struct)) {
return src;
}
CHECK(src_layout_struct.defined() && dst_layout_struct.defined())
<< "cannot convert from/to undefined layout";
auto layout_converter = BijectiveLayoutNode::make(src_layout_struct, dst_layout_struct);
CHECK(layout_converter.defined())
<< "cannot convert from " << src_layout << " to " << dst_layout;
Array<Expr> dst_shape = layout_converter.ForwardShape(src->shape);
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
Array<Expr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
Array<Expr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
return src(src_indices);
}, name, tag);
}
} // namespace topi
#endif // TOPI_TRANSFORM_H_
......@@ -318,3 +318,20 @@ def arange(start, stop=None, step=1, dtype="float32"):
stop = start
start = 0
return cpp.arange(start, stop, step, dtype)
def layout_transform(array, src_layout, dst_layout):
"""Transform the layout according to src_layout and dst_layout
Parameters
----------
array : tvm.Tensor
The source array.
src_layout : str
The source layout.
dst_layout : str
The destination layout.
Returns
-------
ret : tvm.Tensor
The transformed tensor.
"""
return cpp.layout_transform(array, src_layout, dst_layout)
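A minimal usage sketch for the wrapper above (the tensor names are illustrative, not part of the patch):
A = tvm.placeholder((1, 32, 8, 8), name="A")
B = topi.layout_transform(A, "NCHW", "NCHW16c")  # B has shape (1, 2, 8, 8, 16)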
......@@ -272,6 +272,11 @@ TVM_REGISTER_GLOBAL("topi.split")
}
});
TVM_REGISTER_GLOBAL("topi.layout_transform")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = layout_transform(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
......
......@@ -449,6 +449,34 @@ def test_arange():
verify_arange(20, 1, -1.5)
def test_layout_transform():
in_shape = (1, 32, 8, 8)
A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
B = topi.layout_transform(A, "NCHW", "NCHW16c")
input = np.random.uniform(size=in_shape).astype(A.dtype)
output = np.transpose(input, axes=(0, 2, 3, 1))
output = np.reshape(output, newshape=(1, 8, 8, 2, 16))
output = np.transpose(output, axes=(0, 3, 1, 2, 4))
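# NumPy reference: move C to the last axis, split it row-major into (C // 16, 16),
# then move the outer chunk back in front of H/W; this reproduces the
# NCHW -> NCHW16c mapping computed by topi.layout_transform above.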
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
tvm_input = tvm.nd.array(input, ctx)
tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=B.dtype)
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
f = tvm.build(s, [A, B], device, name="layout_transform")
f(tvm_input, tvm_output)
tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
for backend in get_all_backend():
check_device(backend)
if __name__ == "__main__":
test_strided_slice()
test_concatenate()
......@@ -462,3 +490,4 @@ if __name__ == "__main__":
test_take()
test_gather_nd()
test_arange()
test_layout_transform()