Commit 2c231b5a by Lianmin Zheng Committed by Tianqi Chen

[RELAY] Move Layout to tvm Node system (#2125)

parent 2c7d2d78
...@@ -85,7 +85,7 @@ class Var : public HalideIR::VarExpr { ...@@ -85,7 +85,7 @@ class Var : public HalideIR::VarExpr {
/*! /*!
* \brief Container of constant ineteger (IntImm). * \brief Container of constant integer (IntImm).
* *
* This is used to store and automate type check * This is used to store and automate type check
* attributes that must be constant integer. * attributes that must be constant integer.
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
#include "../nn/layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -25,7 +25,7 @@ bool ResizeRel(const Array<Type>& types, ...@@ -25,7 +25,7 @@ bool ResizeRel(const Array<Type>& types,
const ResizeAttrs* param = attrs.as<ResizeAttrs>(); const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->layout); const Layout in_layout(param->layout);
CHECK(in_layout.convertible(kNCHW)) CHECK(in_layout.Convertible(kNCHW))
<< "Resize only support input layouts that are convertible from NCHW." << "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/relay/op/layout.cc
* \brief Layout expression.
*/
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(LayoutNode);
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout) {
CHECK_EQ(src_layout.ndim(), src.size());
if (src_layout == dst_layout) {
return src;
} else if (!src_layout.defined()) {
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
} else if (!dst_layout.defined()) {
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
}
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from "
<< src_layout << " to " << dst_layout;
std::vector<IndexExpr> dst(dst_layout.ndim());
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_dim = src_layout[i];
if (Layout::IsSuperdim(src_dim)) {
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
int src_factor = src_layout.Subsizeof(src_dim);
int dst_factor = dst_layout.Subsizeof(src_dim);
IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) {
CHECK(is_const_int(src[src_minor_pos], src_factor))
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
src_dim_size *= src_factor;
}
dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
CHECK_LE(dst_factor, const_src_dim_size[0])
<< "Converting " << Array<IndexExpr>(src)
<< " from " << src_layout
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
}
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
}
}
}
return dst;
}
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout) {
std::vector<IndexExpr> ret(src.size());
for (size_t i = 0; i < src.size(); ++i) {
ret[i] = src[i];
}
return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file relay/op/layout.h
* \brief Layout expression.
*
* This file is adapted from its nnvm counterpart and will keep involving
* to the new layout system
*
* The layout is composed of upper cases, lower cases and numbers,
* where upper case indicates a (super-)dimension and
* the corresponding lower case with factor size indicates the split (sub-)dimension.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
*/
#ifndef TVM_RELAY_OP_LAYOUT_H_
#define TVM_RELAY_OP_LAYOUT_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/relay/base.h>
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
namespace tvm {
namespace relay {
class LayoutNode : public Node {
public:
std::string name;
Array<Integer> superdim_pos;
Array<Integer> subdim_pos;
Array<Integer> subdim_size;
Array<Integer> layout_simplified;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("superdim_pos", &superdim_pos);
v->Visit("subdim_pos", &subdim_pos);
v->Visit("subdim_size", &subdim_size);
v->Visit("layout_simplified", &layout_simplified);
}
static constexpr const char* _type_key = "Layout";
TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
};
class Layout : public NodeRef {
public:
using LayoutDim = char;
static constexpr uint32_t kUniqueDim = 26;
explicit Layout(NodePtr<Node> n) : NodeRef(n) {}
/*! \brief default constructor */
Layout() : Layout("__undef__") {} // NOLINT(*)
/*! \brief construct from a string */
Layout(const char* str) : Layout(std::string(str)) {} // NOLINT(*)
/*!
* \brief construct from a string.
* \param layout input in layout convention:
* upper case indicates a dimension and
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
Layout(const std::string& layout) { // NOLINT(*)
if (layout.length() != 0) {
Parse(layout);
} else {
Parse("__undef__");
}
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
const LayoutNode* operator->() const {
return static_cast<const LayoutNode*>(node_.get());
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
LayoutNode* operator->() {
return static_cast<LayoutNode*>(node_.get());
}
/*!
* \brief Check whether a given dimension is a super-dimension.
* \param dim input dimension
* \return Whether a given dimension is a super-dimension.
*/
static bool IsSuperdim(LayoutDim dim) {
return dim >= 'A' && dim <= 'Z';
}
/*!
* \brief Check whether a given dimension is a sub-dimension.
* \param dim input dimension
* \return Whether a given dimension is a sub-dimension.
*/
static bool IsSubdim(LayoutDim dim) {
return dim >= 'a' && dim <= 'z';
}
/*!
* \brief Convert a given dimension to super-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim ToSuperdim(LayoutDim dim) {
if (IsSubdim(dim)) {
return dim - 'a' + 'A';
}
return dim;
}
/*!
* \brief Convert a given dimension to sub-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim ToSubdim(LayoutDim dim) {
if (IsSuperdim(dim)) {
return dim - 'A' + 'a';
}
return dim;
}
/*!
* \brief Return an undefined layout.
* \return a (global) undefined layout.
*/
static const Layout& Undef() {
static Layout undef;
return undef;
}
/*!
* \brief Two layouts are convertible only if
* they have same set of super-dimensions.
* e.g., NCHW, NCHW16c, NHWC are convertible between each other,
* but NCHW, CHW, OIHW are not.
* \param dst the target layout
* \return Whether can be converted to dst layout.
*/
bool Convertible(const Layout &dst) const {
const LayoutNode *n = operator->();
if (!this->defined() || !dst.defined()) return false;
for (size_t i = 0; i < kUniqueDim; ++i) {
if ((n->superdim_pos[i]->value >= 0 && dst->superdim_pos[i]->value < 0) ||
(n->superdim_pos[i]->value < 0 && dst->superdim_pos[i]->value >= 0)) {
return false;
}
}
return true;
}
/*!
* \brief Returns a sublayout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
* \return A newly constructed Layout object.
*/
Layout Sublayout(size_t pos, size_t len) const {
const Array<Integer>& layout_simplified = operator->()->layout_simplified;
if (pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos;
if (len == 0) return Layout::Undef();
std::ostringstream new_layout;
for (size_t i = pos; i < pos + len; ++i) {
if (IsSubdim(layout_simplified[i]->value)) {
auto block_size = this->Subsizeof(layout_simplified[i]->value);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified[i]->value;
}
return Layout(new_layout.str());
}
/*! \return A newly constructed reversed Layout object. */
Layout Reverse() const {
const Array<Integer>& layout_simplified = operator->()->layout_simplified;
if (!this->defined()) return Layout::Undef();
std::ostringstream new_layout;
for (int64_t i = this->ndim() - 1; i >= 0; --i) {
if (IsSubdim(layout_simplified[i]->value)) {
auto block_size = this->Subsizeof(layout_simplified[i]->value);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified[i]->value;
}
return Layout(new_layout.str());
}
/*!
* \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
* \param dim The source dimension to be split. It must be a super-dimension.
* \param target_pos The target position of the newly split sub-dimension.
* \param size size of the sub-dimension.
* \return A newly constructed Layout object.
*/
Layout Split(LayoutDim dim, size_t target_pos, uint32_t size) const {
const std::string &name = operator->()->name;
CHECK(target_pos <= this->ndim()) << "Invalid split position "
<< target_pos << " for layout " << name;
CHECK(IsSuperdim(dim)) << "Cannot split a sub-dimension " << dim;
CHECK(this->Contains(dim)) << "Axis " << dim << " does not exist in " << name;
CHECK(!this->Contains(ToSubdim(dim))) << "Dimension " << dim
<< " has already been split in "
<< name;
CHECK(size > 0) << "Invalid split size " << size;
std::ostringstream new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
new_layout << size << Layout::ToSubdim(dim);
}
if (i == this->ndim()) break;
new_layout << this->at(i);
}
Layout x(new_layout.str());
return x;
}
/*! \return number of dimensions */
size_t ndim() const {
return operator->()->layout_simplified.size();
}
/*!
* \brief The description of the \p i-th dimension.
* If it is a sub-dimension, the size will be returned as well,
* e.g., 16c. Otherwise a single character is returned, e.g., C.
* \param i The position
* \return the description of the dimension.
*/
std::string at(size_t i) const {
const Array<Integer>& layout_simplified = operator->()->layout_simplified;
CHECK_LT(i, this->ndim()) << "position " << i
<< " exceeds ndim=" << this->ndim();
std::ostringstream repr;
if (IsSubdim(layout_simplified[i]->value)) {
auto factor = Subsizeof(layout_simplified[i]->value);
CHECK_GT(factor, 0);
repr << factor;
}
repr << static_cast<char>(layout_simplified[i]->value);
return repr.str();
}
/*!
* \brief return the index of the input dimension.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param dim the input dimension.
* \return the index or -1 if not found.
*/
int32_t Indexof(LayoutDim dim) const {
if (!this->defined()) return -1;
else if (IsSuperdim(dim)) return operator->()->superdim_pos[dim - 'A']->value;
else if (IsSubdim(dim)) return operator->()->subdim_pos[dim - 'a']->value;
return -1;
}
/*!
* \param dim the input super-dimension or sub-dimension.
* \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
* or the size of \p dim itself (if \p dim is a sub-dimension).
* Return -1 if \p dim is not in the layout or the layout is undefined.
*/
int64_t Subsizeof(LayoutDim dim) const {
CHECK(IsSuperdim(dim) || IsSubdim(dim)) << "Invalid dim " << dim;
if (!this->defined() || !this->Contains(ToSubdim(dim))) {
return -1;
}
int idx = ToSubdim(dim) - 'a';
return operator->()->subdim_size[idx]->value;
}
/*!
* \brief Whether the layout contains a dimension.
* \param dim dimension to be checked.
* \return Whether the layout contains the dimension.
*/
bool Contains(LayoutDim dim) const {
if (IsSuperdim(dim)) {
return operator->()->superdim_pos[dim-'A']->value >= 0;
} else if (IsSubdim(dim)) {
return operator->()->subdim_pos[dim-'a']->value >= 0;
}
return false;
}
LayoutDim operator[](size_t i) const {
return operator->()->layout_simplified[i];
}
/*! \return whether the layout is defined */
bool defined() const {
return operator->()->name != "__undef__";
}
/*! \return the string description of the layout */
const std::string& name() const {
return operator->()->name;
}
/*!
* \brief Whether the two layouts are equal.
* \param rhs Another layout.
* \return whether the two layouts are equal.
*/
bool Equals(const Layout &rhs) const {
return operator->()->name == rhs->name;
}
using ContainerType = LayoutNode;
private:
void Parse(const std::string &layout) {
node_ = make_node<LayoutNode>();
std::vector<uint32_t> superdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_size(kUniqueDim, -1);
std::vector<char> layout_simplified;
if (layout != "__undef__") { // parse layout string
int32_t factor = 0;
uint32_t curr = 0;
for (size_t i = 0; i < layout.size(); ++i) {
const LayoutDim c = layout.at(i);
if (IsSuperdim(c)) {
int pos = c - 'A';
CHECK_EQ(factor, 0) << "Invalid layout " << layout
<< ": invalid factor size " << factor
<< " before dimension " << c;
CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
superdim_pos[pos] = curr++;
layout_simplified.push_back(c);
} else if (IsSubdim(c)) {
int pos = c - 'a';
CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
<< factor << " for dimension " << c;
CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
subdim_pos[pos] = curr++;
subdim_size[pos] = factor;
layout_simplified.push_back(c);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
CHECK(!layout_simplified.empty()) << "Invalid layout " << layout;
for (LayoutDim dim : layout_simplified) {
CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
<< "Invalid layout " << layout << ": missing axis "
<< static_cast<char>(dim - 'a' + 'A');
}
}
LayoutNode *node = operator->();
node->name = layout;
for (uint32_t i = 0; i < kUniqueDim; ++i) {
node->superdim_pos.push_back(superdim_pos[i]);
node->subdim_pos.push_back(subdim_pos[i]);
node->subdim_size.push_back(subdim_size[i]);
}
for (LayoutDim dim : layout_simplified) {
node->layout_simplified.push_back(dim);
}
}
};
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout);
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_LAYOUT_H_
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "layout.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -28,16 +29,16 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -28,16 +29,16 @@ bool Conv2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->data_layout); const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout); const Layout kernel_layout(param->weight_layout);
CHECK(in_layout.convertible(kNCHW)) CHECK(in_layout.Convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW." << "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
CHECK(kernel_layout.convertible(kOIHW)) CHECK(kernel_layout.Convertible(kOIHW))
<< "Conv only support kernel layouts that are convertible from OIHW." << "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout; << " But got "<< kernel_layout;
Layout out_layout(param->out_layout); Layout out_layout(param->out_layout);
if (!out_layout.defined()) out_layout = in_layout; if (!out_layout.defined()) out_layout = in_layout;
CHECK(out_layout.convertible(kNCHW)) CHECK(out_layout.Convertible(kNCHW))
<< "Conv only support output layouts that are convertible from NCHW." << "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout; << " But got " << out_layout;
...@@ -55,7 +56,7 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -55,7 +56,7 @@ bool Conv2DRel(const Array<Type>& types,
param->kernel_size[0], param->kernel_size[0],
param->kernel_size[1]}); param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout); wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
wshape[kernel_layout.indexof('O')] *= param->groups; wshape[kernel_layout.Indexof('O')] *= param->groups;
channels = param->channels; channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
...@@ -177,10 +178,10 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -177,10 +178,10 @@ bool Conv2DTransposeRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->data_layout); const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout); const Layout kernel_layout(param->weight_layout);
CHECK(in_layout.convertible(kNCHW)) CHECK(in_layout.Convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW." << "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
CHECK(kernel_layout.convertible(kOIHW)) CHECK(kernel_layout.Convertible(kOIHW))
<< "Conv only support kernel layouts that are convertible from OIHW." << "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout; << " But got "<< kernel_layout;
......
/*!
* Copyright (c) 2018 by Contributors
* \file relay/op/nn/layout.h
* \brief Layout expression.
*
* This file is adapted from its nnvm counterpart and will keep involving
* to the new layout system
*
* The layout is composed of upper cases, lower cases and numbers,
* where upper case indicates a (super-)dimension and
* the corresponding lower case with factor size indicates the split (sub-)dimension.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
*/
#ifndef TVM_RELAY_OP_NN_LAYOUT_H_
#define TVM_RELAY_OP_NN_LAYOUT_H_
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
namespace tvm {
namespace relay {
/*! \brief layout auxiliary structure */
class Layout {
public:
using LayoutDim = char;
/*! \brief default constructor */
Layout() : name_("__undef__") {} // NOLINT(*)
/*!
* \brief construct from a string.
* \param layout input in layout convention:
* upper case indicates a dimension and
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
Layout(const std::string& layout) { // NOLINT(*)
if (layout.length() != 0) {
parse(layout);
} else {
parse("__undef__");
}
}
/*!
* \brief copy constructor from another layout
* \param s the source layout
*/
Layout(const Layout& s) { // NOLINT(*)
this->parse(s.name_);
}
/*!
* \brief move constructor from Layout
* \param src the source layout
*/
Layout(Layout&& src) { // NOLINT(*)
this->swap(src);
}
/*!
* \brief assignment from another layout.
* \param src source layout
* \return reference of self
*/
Layout& operator=(const Layout& src) {
this->parse(src.name_);
return *this;
}
/*!
* \brief assignment from rvalue of another layout.
* \param src source layout
* \return reference of self
*/
Layout& operator=(Layout&& src) {
Layout(std::move(src)).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief assignment from string.
* \param src source layout
* \return reference of self
*/
Layout& operator=(const std::string& src) {
this->parse(src);
return *this;
}
/*!
* \return whether two layout equals
* \param s the layout to compare against
*/
bool operator==(const Layout& s) const {
return name_ == s.name_;
}
/*!
* \return whether two layout not equal
* \param s the layout to compare against
*/
bool operator!=(const Layout& s) const {
return !(*this == s);
}
/*!
* \brief Append the current layout by another.
* @param other the layout to be appended
* @return a new layout
*/
Layout operator+(const Layout& other) const {
if (!this->defined() && !other.defined()) {
return Layout::Undef();
} else if (!this->defined()) {
return other;
} else if (!other.defined()) {
return *this;
}
return Layout(this->name_ + other.name_);
}
/*!
* \brief Check whether a given dimension is a super-dimension.
* \param dim input dimension
* \return Whether a given dimension is a super-dimension.
*/
static bool is_superdim(LayoutDim dim) {
return dim >= 'A' && dim <= 'Z';
}
/*!
* \brief Check whether a given dimension is a sub-dimension.
* \param dim input dimension
* \return Whether a given dimension is a sub-dimension.
*/
static bool is_subdim(LayoutDim dim) {
return dim >= 'a' && dim <= 'z';
}
/*!
* \brief Convert a given dimension to super-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim to_superdim(LayoutDim dim) {
if (is_subdim(dim)) {
return dim - 'a' + 'A';
}
return dim;
}
/*!
* \brief Convert a given dimension to sub-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim to_subdim(LayoutDim dim) {
if (is_superdim(dim)) {
return dim - 'A' + 'a';
}
return dim;
}
/*!
* \brief Return an undefined layout.
* \return a (global) undefined layout.
*/
static const Layout& Undef() {
static Layout undef;
return undef;
}
/*!
* \brief Swap current object with other
* \param other another object to be swapped.
*/
void swap(Layout& other) { // NOLINT(*)
std::swap(name_, other.name_);
std::swap(superdim_pos_, other.superdim_pos_);
std::swap(subdim_pos_, other.subdim_pos_);
std::swap(subdim_size_, other.subdim_size_);
std::swap(layout_simplified_, other.layout_simplified_);
}
/*!
* \brief Two layouts are convertible only if
* they have same set of super-dimensions.
* e.g., NCHW, NCHW16c, NHWC are convertible between each other,
* but NCHW, CHW, OIHW are not.
* \param dst the target layout
* \return Whether can be converted to dst layout.
*/
bool convertible(const Layout &dst) const {
if (!this->defined() || !dst.defined()) return false;
for (size_t i = 0; i < kUniqueDim; ++i) {
if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
(superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
return false;
}
}
return true;
}
/*!
* \brief Returns a sublayout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
* \return A newly constructed Layout object.
*/
Layout sublayout(size_t pos, size_t len) const {
if (pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos;
if (len == 0) return Layout::Undef();
std::ostringstream new_layout;
for (size_t i = pos; i < pos + len; ++i) {
if (is_subdim(layout_simplified_[i])) {
auto block_size = this->subsizeof(layout_simplified_[i]);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified_[i];
}
return Layout(new_layout.str());
}
/*! \return A newly constructed reversed Layout object. */
Layout reverse() const {
if (!this->defined()) return Layout::Undef();
std::ostringstream new_layout;
for (int64_t i = this->ndim() - 1; i >= 0; --i) {
if (is_subdim(layout_simplified_[i])) {
auto block_size = this->subsizeof(layout_simplified_[i]);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified_[i];
}
return Layout(new_layout.str());
}
/*!
* \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
* \param dim The source dimension to be split. It must be a super-dimension.
* \param target_pos The target position of the newly split sub-dimension.
* \param size size of the sub-dimension.
* \return A newly constructed Layout object.
*/
Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
CHECK(target_pos <= this->ndim()) << "Invalid split position "
<< target_pos << " for layout " << name_;
CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
<< " has already been split in "
<< name_;
CHECK(size > 0) << "Invalid split size " << size;
std::ostringstream new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
new_layout << size << Layout::to_subdim(dim);
}
if (i == this->ndim()) break;
new_layout << this->at(i);
}
Layout x(new_layout.str());
return x;
}
using iterator = std::vector<LayoutDim>::const_iterator;
using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
/*! \return begin iterator */
iterator begin() const {
return layout_simplified_.begin();
}
/*! \return end iterator */
iterator end() const {
return layout_simplified_.end();
}
/*! \return rbegin iterator */
reverse_iterator rbegin() const {
return layout_simplified_.rbegin();
}
/*! \return rend iterator */
reverse_iterator rend() const {
return layout_simplified_.rend();
}
/*! \return number of dimensions */
size_t ndim() const {
return layout_simplified_.size();
}
/*!
* \brief The description of the \p i-th dimension.
* If it is a sub-dimension, the size will be returned as well,
* e.g., 16c. Otherwise a single character is returned, e.g., C.
* \param i The position
* \return the description of the dimension.
*/
std::string at(size_t i) const {
CHECK_LT(i, this->ndim()) << "position " << i
<< " exceeds ndim=" << this->ndim();
std::ostringstream repr;
if (is_subdim(layout_simplified_[i])) {
auto factor = subsizeof(layout_simplified_[i]);
CHECK_GT(factor, 0);
repr << factor;
}
repr << layout_simplified_[i];
return repr.str();
}
/*!
* \brief return the index of the input dimension.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param dim the input dimension.
* \return the index or -1 if not found.
*/
int32_t indexof(LayoutDim dim) const {
if (!this->defined()) return -1;
else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
return -1;
}
/*!
* \param dim the input super-dimension or sub-dimension.
* \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
* or the size of \p dim itself (if \p dim is a sub-dimension).
* Return -1 if \p dim is not in the layout or the layout is undefined.
*/
int64_t subsizeof(LayoutDim dim) const {
CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
if (!this->defined() || !this->contains(to_subdim(dim))) {
return -1;
}
int idx = to_subdim(dim) - 'a';
return subdim_size_[idx];
}
/*!
* \brief Whether the layout contains a dimension.
* \param dim dimension to be checked.
* \return Whether the layout contains the dimension.
*/
bool contains(LayoutDim dim) const {
if (is_superdim(dim)) {
return superdim_pos_[dim-'A'] >= 0;
} else if (is_subdim(dim)) {
return subdim_pos_[dim-'a'] >= 0;
}
return false;
}
LayoutDim operator[](size_t i) const {
return layout_simplified_[i];
}
/*! \return whether the layout is defined */
bool defined() const {
return name_ != "__undef__";
}
/*! \return the string description of the layout */
const std::string& name() const {
return name_;
}
/*!
* \brief Write layout in JSON format.
* \param writer JSONWriter
*/
void Save(dmlc::JSONWriter* writer) const {
writer->Write(name_);
}
/*!
* \brief Load layout from JSON.
* \param reader JSONReader
*/
void Load(dmlc::JSONReader* reader) {
std::string tmp;
reader->Read(&tmp);
this->parse(tmp);
}
/*!
* \brief allow output string of layout to ostream
* \param os the output stream
* \param l the layout
* \return the ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
os << l.name_;
return os;
}
private:
static const uint32_t kUniqueDim = 26;
std::string name_;
int32_t superdim_pos_[kUniqueDim];
int32_t subdim_pos_[kUniqueDim];
int64_t subdim_size_[kUniqueDim];
std::vector<LayoutDim> layout_simplified_;
void parse(const std::string& layout) {
name_ = layout;
std::fill_n(superdim_pos_, kUniqueDim, -1);
std::fill_n(subdim_pos_, kUniqueDim, -1);
std::fill_n(subdim_size_, kUniqueDim, -1);
layout_simplified_.clear();
if (layout == "__undef__") return;
int32_t factor = 0;
uint32_t curr = 0;
for (size_t i = 0; i < layout.size(); ++i) {
const LayoutDim c = layout.at(i);
if (is_superdim(c)) {
int pos = c - 'A';
CHECK_EQ(factor, 0) << "Invalid layout " << layout
<< ": invalid factor size " << factor
<< " before dimension " << c;
CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
superdim_pos_[pos] = curr++;
layout_simplified_.push_back(c);
} else if (is_subdim(c)) {
int pos = c - 'a';
CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
<< factor << " for dimension " << c;
CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
subdim_pos_[pos] = curr++;
subdim_size_[pos] = factor;
layout_simplified_.push_back(c);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
for (LayoutDim dim : layout_simplified_) {
CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
<< "Invalid layout " << layout << ": missing axis "
<< static_cast<char>(dim - 'a' + 'A');
}
}
};
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
inline std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout) {
CHECK_EQ(src_layout.ndim(), src.size());
if (src_layout == dst_layout) {
return src;
} else if (!src_layout.defined()) {
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
} else if (!dst_layout.defined()) {
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
}
CHECK(src_layout.convertible(dst_layout))
<< "cannot convert from "
<< src_layout << " to " << dst_layout;
std::vector<IndexExpr> dst(dst_layout.ndim());
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_dim = src_layout[i];
if (Layout::is_superdim(src_dim)) {
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_dim));
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_dim));
int src_minor_pos = src_layout.indexof(Layout::to_subdim(src_dim));
int src_factor = src_layout.subsizeof(src_dim);
int dst_factor = dst_layout.subsizeof(src_dim);
IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) {
CHECK(is_const_int(src[src_minor_pos], src_factor))
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
src_dim_size *= src_factor;
}
dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
CHECK_LE(dst_factor, const_src_dim_size[0])
<< "Converting " << Array<IndexExpr>(src)
<< " from " << src_layout
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
}
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
}
}
}
return dst;
}
inline std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout) {
std::vector<IndexExpr> ret(src.size());
for (size_t i = 0; i < src.size(); ++i) {
ret[i] = src[i];
}
return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_LAYOUT_H_
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include <vector> #include <vector>
#include "../type_relations.h" #include "../type_relations.h"
#include "../op_common.h" #include "../op_common.h"
#include "layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h> #include <topi/nn/pooling.h>
#include <vector> #include <vector>
#include "layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -33,13 +33,13 @@ bool Pool2DRel(const Array<Type>& types, ...@@ -33,13 +33,13 @@ bool Pool2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.contains('H') && layout.contains('W') && CHECK(layout.Contains('H') && layout.Contains('W') &&
!layout.contains('h') && !layout.contains('w')) !layout.Contains('h') && !layout.Contains('w'))
<< "Invalid layout " << layout << "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split"; << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.indexof('H'); const auto hidx = layout.Indexof('H');
const auto widx = layout.indexof('W'); const auto widx = layout.Indexof('W');
IndexExpr pad_h, pad_w; IndexExpr pad_h, pad_w;
if (param->padding.size() == 1) { if (param->padding.size() == 1) {
...@@ -102,10 +102,10 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -102,10 +102,10 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
auto padding = param->padding; auto padding = param->padding;
auto ceil_mode = param->ceil_mode; auto ceil_mode = param->ceil_mode;
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.convertible(Layout("NCHW"))) CHECK(layout.Convertible(Layout("NCHW")))
<< "max_pool2d currently only supports layouts that are convertible from NCHW"; << "max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.indexof('h'), -1) << "max_pool2d does not support input split on height"; CHECK_EQ(layout.Indexof('h'), -1) << "max_pool2d does not support input split on height";
CHECK_EQ(layout.indexof('w'), -1) << "max_pool2d does not support input split on width"; CHECK_EQ(layout.Indexof('w'), -1) << "max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)" << "Pool2D only support 4-D input (e.g., NCHW)"
...@@ -240,13 +240,13 @@ bool GlobalPool2DRel(const Array<Type>& types, ...@@ -240,13 +240,13 @@ bool GlobalPool2DRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.contains('H') && layout.contains('W') && CHECK(layout.Contains('H') && layout.Contains('W') &&
!layout.contains('h') && !layout.contains('w')) !layout.Contains('h') && !layout.Contains('w'))
<< "Invalid layout " << layout << "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split"; << ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.indexof('H'); const auto hidx = layout.Indexof('H');
const auto widx = layout.indexof('W'); const auto widx = layout.Indexof('W');
std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]}); std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
oshape[hidx] = oshape[widx] = 1; oshape[hidx] = oshape[widx] = 1;
...@@ -264,11 +264,11 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs, ...@@ -264,11 +264,11 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
const auto* param = attrs.as<GlobalPool2DAttrs>(); const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout layout(param->layout); Layout layout(param->layout);
CHECK(layout.convertible(Layout("NCHW"))) CHECK(layout.Convertible(Layout("NCHW")))
<< "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.indexof('h'), -1) CHECK_EQ(layout.Indexof('h'), -1)
<< "global_avg_pool2d does not support input split on height"; << "global_avg_pool2d does not support input split on height";
CHECK_EQ(layout.indexof('w'), -1) CHECK_EQ(layout.Indexof('w'), -1)
<< "global_avg_pool2d does not support input split on width"; << "global_avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include "layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -25,7 +25,7 @@ bool UpSamplingRel(const Array<Type>& types, ...@@ -25,7 +25,7 @@ bool UpSamplingRel(const Array<Type>& types,
const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>(); const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->layout); const Layout in_layout(param->layout);
CHECK(in_layout.convertible(kNCHW)) CHECK(in_layout.Convertible(kNCHW))
<< "UpSampling only support input layouts that are convertible from NCHW." << "UpSampling only support input layouts that are convertible from NCHW."
<< " But got " << in_layout; << " But got " << in_layout;
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "pattern_util.h" #include "pattern_util.h"
#include "pass_util.h" #include "pass_util.h"
#include "../op/nn/layout.h" #include "../op/layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -378,8 +379,8 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) { ...@@ -378,8 +379,8 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout weight_layout(param->weight_layout); Layout weight_layout(param->weight_layout);
int c_big_axis = data_layout.indexof('C'); int c_big_axis = data_layout.Indexof('C');
int c_small_axis = data_layout.indexof('c'); int c_small_axis = data_layout.Indexof('c');
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
AxesSet data_axes = NullValue<AxesSet>(); AxesSet data_axes = NullValue<AxesSet>();
...@@ -391,7 +392,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) { ...@@ -391,7 +392,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
if (weight_layout.indexof('i') < 0 && if (weight_layout.Indexof('i') < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis}; data_axes = {c_big_axis};
...@@ -412,15 +413,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call, ...@@ -412,15 +413,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout weight_layout(param->weight_layout); Layout weight_layout(param->weight_layout);
int c_big_axis = data_layout.indexof('C'); int c_big_axis = data_layout.Indexof('C');
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(weight_layout.indexof('i'), -1); CHECK_EQ(weight_layout.Indexof('i'), -1);
CHECK(sdata->axes.size() == 1 && CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value); c_big_axis == sdata->axes[0]->value);
int big_oc_axis = weight_layout.indexof('O'); int big_oc_axis = weight_layout.Indexof('O');
int big_ic_axis = weight_layout.indexof('I'); int big_ic_axis = weight_layout.Indexof('I');
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout);
...@@ -779,8 +780,8 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { ...@@ -779,8 +780,8 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
out_layout = Layout(param->data_layout); out_layout = Layout(param->data_layout);
} }
Layout weight_layout(param->weight_layout); Layout weight_layout(param->weight_layout);
int c_big_axis = out_layout.indexof('C'); int c_big_axis = out_layout.Indexof('C');
int c_small_axis = out_layout.indexof('c'); int c_small_axis = out_layout.Indexof('c');
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
...@@ -791,8 +792,8 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { ...@@ -791,8 +792,8 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
if (weight_layout.indexof('o') < 0 && if (weight_layout.Indexof('o') < 0 &&
weight_layout.indexof('i') < 0 && weight_layout.Indexof('i') < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
return {c_big_axis}; return {c_big_axis};
...@@ -816,16 +817,16 @@ Expr Conv2DBackwardTransform(const Call& call, ...@@ -816,16 +817,16 @@ Expr Conv2DBackwardTransform(const Call& call,
out_layout = Layout(param->data_layout); out_layout = Layout(param->data_layout);
} }
Layout weight_layout(param->weight_layout); Layout weight_layout(param->weight_layout);
int c_big_axis = out_layout.indexof('C'); int c_big_axis = out_layout.Indexof('C');
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(weight_layout.indexof('o'), -1); CHECK_EQ(weight_layout.Indexof('o'), -1);
CHECK_EQ(weight_layout.indexof('i'), -1); CHECK_EQ(weight_layout.Indexof('i'), -1);
CHECK(axes.size() == 1 && CHECK(axes.size() == 1 &&
c_big_axis == axes[0]->value); c_big_axis == axes[0]->value);
int big_oc_axis = weight_layout.indexof('O'); int big_oc_axis = weight_layout.Indexof('O');
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d); CHECK(param->groups == 1 || is_depthwise_conv2d);
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include "../op/nn/layout.h" #include "../op/layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment