Commit 9f8fcfc9 by Yizhi Liu Committed by Tianqi Chen

General Layout Support (#447)

parent fc7e9cd2
/*!
* Copyright (c) 2017 by Contributors
* \file contrib_op_param.h
* \brief Additional parameters for compiler optimized operators.
*/
#ifndef NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#define NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#include <dmlc/parameter.h>
#include <string>
namespace nnvm {
namespace compiler {
/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
/*!
 * \brief Layout string registered as the "src_layout" field.
 * NOTE(review): presumably the layout the input currently has -
 * confirm against the operator implementation.
 */
std::string src_layout;
/*!
 * \brief Layout string registered as the "dst_layout" field.
 * NOTE(review): presumably the layout the output is converted to -
 * confirm against the operator implementation.
 */
std::string dst_layout;
// Register both fields with the dmlc parameter system.
// No default value or description is declared for either field here.
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout);
DMLC_DECLARE_FIELD(dst_layout);
}
};
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_CONTRIB_OP_PARAM_H_
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include "packed_func_ext.h"
namespace nnvm { namespace nnvm {
namespace compiler { namespace compiler {
...@@ -73,19 +74,17 @@ using FTVMSchedule = std::function< ...@@ -73,19 +74,17 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs, const Array<Tensor>& outs,
const std::string& target)>; const std::string& target)>;
/*! \brief Layout Information about an entry */
using TLayoutInfo = std::string;
/*! /*!
* \brief The producer consumer function of node layout * \brief Modify the op node to alter its input layout.
* \param attrs The attribute of the node. * it is invoked in AlterOpLayout pass.
* \param ilayouts The input layouts that the node request. * \param attrs The attribute of the original node.
* \param olayouts The output layouts that the node produce. * \param inputs The input symbols of the original node.
* \return bool The success flag. * \param tinfos The inferred shape and dtype of the inputs.
*/ */
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs, using FTVMAlterOpLayout = std::function<
std::vector<TLayoutInfo> *ilayouts, Symbol(const NodeAttrs& attrs,
std::vector<TLayoutInfo> *olayouts)>; const Symbol& inputs,
const Array<Tensor>& tinfos)>;
/*! /*!
* \brief Transform from normal operator to vectorized operator * \brief Transform from normal operator to vectorized operator
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <nnvm/symbolic.h> #include <nnvm/symbolic.h>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
namespace nnvm { namespace nnvm {
...@@ -52,6 +53,7 @@ template<> ...@@ -52,6 +53,7 @@ template<>
struct extension_class_info<nnvm::compiler::AttrDict> { struct extension_class_info<nnvm::compiler::AttrDict> {
static const int code = 18; static const int code = 18;
}; };
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_ #endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "./tuple.h" #include "./tuple.h"
#include "./layout.h"
namespace nnvm { namespace nnvm {
...@@ -46,7 +47,7 @@ using ShapeVector = std::vector<TShape>; ...@@ -46,7 +47,7 @@ using ShapeVector = std::vector<TShape>;
* \code * \code
* Graph g = ApplyPass(src_graph, "InferType"); * Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype"); * const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id * // get type by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
* \endcode * \endcode
* *
...@@ -55,6 +56,21 @@ using ShapeVector = std::vector<TShape>; ...@@ -55,6 +56,21 @@ using ShapeVector = std::vector<TShape>;
using DTypeVector = std::vector<int>; using DTypeVector = std::vector<int>;
/*! /*!
* \brief The result holder of layout of each NodeEntry in the graph.
* \note Stored under graph.attrs["layout"], provided by Pass "CorrectLayout"
*
* \code
* Graph g = ApplyPass(src_graph, "CorrectLayout");
* const LayoutVector& layouts = g.GetAttr<LayoutVector>("layout");
* // get layout by entry id
* const Layout& entry_layout = layouts[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferLayout
*/
using LayoutVector = std::vector<Layout>;
/*!
* \brief The result holder of device of each operator in the graph. * \brief The result holder of device of each operator in the graph.
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
* *
......
/*!
* Copyright (c) 2018 by Contributors
* \file layout.h
* \brief Layout expression.
* The layout is composed of upper cases, lower cases and numbers,
* where upper case indicates a (super-)dimension and
* the corresponding lower case with factor size indicates the split (sub-)dimension.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
*/
#ifndef NNVM_LAYOUT_H_
#define NNVM_LAYOUT_H_
#include <dmlc/parameter.h>
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
namespace nnvm {
/*!
 * \brief Layout expression of a tensor, e.g. NCHW or NCHW16c.
 *
 *  A layout is parsed from a string composed of upper case letters
 *  (super-dimensions), lower case letters (sub-dimensions) and numbers
 *  (the factor size that precedes a sub-dimension).
 *  The special name "__undef__" denotes an undefined layout.
 */
class Layout {
 public:
  using LayoutDim = char;

  /*! \brief default constructor, creates an undefined layout. */
  Layout() : name_("__undef__") {  // NOLINT(*)
    // Initialize the lookup tables so that queries such as contains()
    // never read indeterminate values on an undefined layout.
    std::fill_n(superdim_pos_, kUniqueDim, -1);
    std::fill_n(subdim_pos_, kUniqueDim, -1);
    std::fill_n(subdim_size_, kUniqueDim, -1);
  }

  /*!
   * \brief construct from a string.
   * \param layout input in layout convention:
   *        upper case indicates a dimension and
   *        the corresponding lower case with factor size
   *        indicates the split dimension.
   *        return undefined layout if "__undef__" is passed.
   */
  inline Layout(const std::string& layout) {  // NOLINT(*)
    parse(layout);
  }

  /*!
   * \brief copy constructor from another layout
   * \param s the source layout
   */
  inline Layout(const Layout& s) {  // NOLINT(*)
    this->parse(s.name_);
  }

  /*!
   * \brief move constructor from Layout.
   *  The source is left as an undefined layout.
   * \param src the source layout
   */
  inline Layout(Layout&& src) : Layout() {  // NOLINT(*)
    // Delegate to the default constructor first, so that the state
    // swapped into \p src is fully initialized.
    this->swap(src);
  }

  /*!
   * \brief assignment from another layout.
   * \param src source layout
   * \return reference of self
   */
  inline Layout& operator=(const Layout& src) {
    this->parse(src.name_);
    return *this;
  }

  /*!
   * \brief assignment from rvalue of another layout.
   * \param src source layout
   * \return reference of self
   */
  inline Layout& operator=(Layout&& src) {
    Layout(std::move(src)).swap(*this);  // NOLINT(*)
    return *this;
  }

  /*!
   * \brief assignment from string.
   * \param src source layout
   * \return reference of self
   */
  inline Layout& operator=(const std::string& src) {
    this->parse(src);
    return *this;
  }

  /*!
   * \return whether two layout equals
   * \param s the layout to compare against
   */
  inline bool operator==(const Layout& s) const {
    return name_ == s.name_;
  }

  /*!
   * \return whether two layout not equal
   * \param s the layout to compare against
   */
  inline bool operator!=(const Layout& s) const {
    return !(*this == s);
  }

  /*!
   * \brief Append the current layout by another.
   *  If one side is undefined the other side is returned;
   *  if both are undefined the result is undefined.
   * \param other the layout to be appended
   * \return a new layout
   */
  inline Layout operator+(const Layout& other) const {
    if (!this->defined() && !other.defined()) {
      return Layout::Undef();
    } else if (!this->defined()) {
      return other;
    } else if (!other.defined()) {
      return *this;
    }
    return Layout(this->name_ + other.name_);
  }

  /*!
   * \brief Check whether a given dimension is a super-dimension.
   * \param dim input dimension
   * \return Whether a given dimension is a super-dimension.
   */
  static inline bool is_superdim(LayoutDim dim) {
    return dim >= 'A' && dim <= 'Z';
  }

  /*!
   * \brief Check whether a given dimension is a sub-dimension.
   * \param dim input dimension
   * \return Whether a given dimension is a sub-dimension.
   */
  static inline bool is_subdim(LayoutDim dim) {
    return dim >= 'a' && dim <= 'z';
  }

  /*!
   * \brief Convert a given dimension to super-dimension.
   * \param dim input dimension
   * \return The converted description.
   */
  static inline LayoutDim to_superdim(LayoutDim dim) {
    if (is_subdim(dim)) {
      return dim - 'a' + 'A';
    }
    return dim;
  }

  /*!
   * \brief Convert a given dimension to sub-dimension.
   * \param dim input dimension
   * \return The converted description.
   */
  static inline LayoutDim to_subdim(LayoutDim dim) {
    if (is_superdim(dim)) {
      return dim - 'A' + 'a';
    }
    return dim;
  }

  /*!
   * \brief Return an undefined layout.
   * \return a (global) undefined layout.
   */
  static inline const Layout& Undef() {
    static Layout undef;
    return undef;
  }

  /*!
   * \brief Swap current object with other
   * \param other another object to be swapped.
   */
  inline void swap(Layout& other) {  // NOLINT(*)
    std::swap(name_, other.name_);
    std::swap(superdim_pos_, other.superdim_pos_);
    std::swap(subdim_pos_, other.subdim_pos_);
    std::swap(subdim_size_, other.subdim_size_);
    std::swap(layout_simplified_, other.layout_simplified_);
  }

  /*!
   * \brief Two layouts are convertible only if
   *        they have same set of super-dimensions.
   *        e.g., NCHW, NCHW16c, NHWC are convertible between each other,
   *        but NCHW, CHW, OIHW are not.
   * \param dst the target layout
   * \return Whether can be converted to dst layout.
   */
  inline bool convertible(const Layout &dst) const {
    if (!this->defined() || !dst.defined()) return false;
    for (size_t i = 0; i < kUniqueDim; ++i) {
      if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
          (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
        return false;
      }
    }
    return true;
  }

  /*!
   * \brief Returns a sublayout which is the portion of the object
   *        that starts at dimension \p pos and spans \p len dimensions
   *        (or until the end of the layout, whichever comes first).
   * \param pos The start position.
   * \param len The length of the sub-layout.
   * \return A newly constructed Layout object,
   *         or an undefined layout if the requested range is empty.
   */
  inline Layout sublayout(size_t pos, size_t len) const {
    if (pos > ndim()) return Layout::Undef();
    // Compare against the remaining length instead of computing pos + len,
    // which could wrap around for a very large len.
    if (len > ndim() - pos) len = ndim() - pos;
    if (len == 0) return Layout::Undef();
    std::ostringstream new_layout;
    for (size_t i = pos; i < pos + len; ++i) {
      if (is_subdim(layout_simplified_[i])) {
        auto block_size = this->subsizeof(layout_simplified_[i]);
        CHECK_GT(block_size, 0);
        new_layout << block_size;
      }
      new_layout << layout_simplified_[i];
    }
    return Layout(new_layout.str());
  }

  /*! \return A newly constructed reversed Layout object. */
  inline Layout reverse() const {
    if (!this->defined()) return Layout::Undef();
    std::ostringstream new_layout;
    for (int64_t i = this->ndim() - 1; i >= 0; --i) {
      if (is_subdim(layout_simplified_[i])) {
        auto block_size = this->subsizeof(layout_simplified_[i]);
        CHECK_GT(block_size, 0);
        new_layout << block_size;
      }
      new_layout << layout_simplified_[i];
    }
    return Layout(new_layout.str());
  }

  /*!
   * \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
   * \param dim The source dimension to be split. It must be a super-dimension.
   * \param target_pos The target position of the newly split sub-dimension.
   * \param size size of the sub-dimension.
   * \return A newly constructed Layout object.
   */
  inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
    CHECK(target_pos <= this->ndim()) << "Invalid split position "
                                      << target_pos << " for layout " << name_;
    CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
    CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
    CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
                                           << " has already been split in "
                                           << name_;
    CHECK(size > 0) << "Invalid split size " << size;
    std::ostringstream new_layout;
    // Iterate one past the end so that target_pos == ndim() appends
    // the new sub-dimension at the tail.
    for (size_t i = 0; i <= this->ndim(); ++i) {
      if (i == target_pos) {
        new_layout << size << Layout::to_subdim(dim);
      }
      if (i == this->ndim()) break;
      new_layout << this->at(i);
    }
    Layout x(new_layout.str());
    return x;
  }

  using iterator = std::vector<LayoutDim>::const_iterator;
  using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;

  /*! \return begin iterator */
  inline iterator begin() const {
    return layout_simplified_.begin();
  }
  /*! \return end iterator */
  inline iterator end() const {
    return layout_simplified_.end();
  }
  /*! \return rbegin iterator */
  inline reverse_iterator rbegin() const {
    return layout_simplified_.rbegin();
  }
  /*! \return rend iterator */
  inline reverse_iterator rend() const {
    return layout_simplified_.rend();
  }

  /*! \return number of dimensions */
  inline size_t ndim() const {
    return layout_simplified_.size();
  }

  /*!
   * \brief The description of the \p i-th dimension.
   *        If it is a sub-dimension, the size will be returned as well,
   *        e.g., 16c. Otherwise a single character is returned, e.g., C.
   * \param i The position
   * \return the description of the dimension.
   */
  inline std::string at(size_t i) const {
    CHECK_LT(i, this->ndim()) << "position " << i
                              << " exceeds ndim=" << this->ndim();
    std::ostringstream repr;
    if (is_subdim(layout_simplified_[i])) {
      auto factor = subsizeof(layout_simplified_[i]);
      // A sub-dimension present in the layout always has a positive factor.
      CHECK_GT(factor, 0);
      repr << factor;
    }
    repr << layout_simplified_[i];
    return repr.str();
  }

  /*!
   * \brief return the index of the input dimension.
   *        If it is not found in the layout or the layout is undefined,
   *        return -1.
   * \param dim the input dimension.
   * \return the index or -1 if not found.
   */
  inline int32_t indexof(LayoutDim dim) const {
    if (!this->defined()) return -1;
    else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
    else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
    return -1;
  }

  /*!
   * \param dim the input super-dimension or sub-dimension.
   * \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
   *         or the size of \p dim itself (if \p dim is a sub-dimension).
   *         Return -1 if \p dim is not in the layout or the layout is undefined.
   */
  inline int64_t subsizeof(LayoutDim dim) const {
    CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
    if (!this->defined() || !this->contains(to_subdim(dim))) {
      return -1;
    }
    int idx = to_subdim(dim) - 'a';
    return subdim_size_[idx];
  }

  /*!
   * \brief Whether the layout contains a dimension.
   * \param dim dimension to be checked.
   * \return Whether the layout contains the dimension.
   */
  inline bool contains(LayoutDim dim) const {
    if (is_superdim(dim)) {
      return superdim_pos_[dim-'A'] >= 0;
    } else if (is_subdim(dim)) {
      return subdim_pos_[dim-'a'] >= 0;
    }
    return false;
  }

  /*!
   * \brief Unchecked access to the \p i-th dimension letter (without factor).
   * \param i The position
   * \return the dimension character.
   */
  inline const LayoutDim operator[](size_t i) const {
    return layout_simplified_[i];
  }

  /*! \return whether the layout is defined */
  inline bool defined() const {
    return name_ != "__undef__";
  }

  /*! \return the string description of the layout */
  inline const std::string& name() const {
    return name_;
  }

  /*!
   * \brief Write layout in JSON format.
   * \param writer JSONWriter
   */
  inline void Save(dmlc::JSONWriter* writer) const {
    writer->Write(name_);
  }

  /*!
   * \brief Load layout from JSON.
   * \param reader JSONReader
   */
  inline void Load(dmlc::JSONReader* reader) {
    std::string tmp;
    reader->Read(&tmp);
    this->parse(tmp);
  }

  /*!
   * \brief allow output string of layout to ostream
   * \param os the output stream
   * \param l the layout
   * \return the ostream
   */
  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
    os << l.name_;
    return os;
  }

 private:
  /*! \brief number of distinct dimension letters ('A'-'Z' / 'a'-'z') */
  static const uint32_t kUniqueDim = 26;

  /*! \brief the layout string as given, or "__undef__" */
  std::string name_;
  /*! \brief position of each super-dimension, -1 if absent */
  int32_t superdim_pos_[kUniqueDim];
  /*! \brief position of each sub-dimension, -1 if absent */
  int32_t subdim_pos_[kUniqueDim];
  /*! \brief factor size of each sub-dimension, -1 if absent */
  int64_t subdim_size_[kUniqueDim];
  /*! \brief dimension letters in order, with factor sizes stripped */
  std::vector<LayoutDim> layout_simplified_;

  /*!
   * \brief Reset the object and fill it from a layout string.
   *  Fails (via CHECK / LOG(FATAL)) on malformed input: unknown characters,
   *  duplicate dimensions, a factor in front of a super-dimension,
   *  a sub-dimension without a factor, an empty layout, or
   *  a sub-dimension whose super-dimension is missing.
   * \param layout the layout string, or "__undef__".
   */
  void parse(const std::string& layout) {
    name_ = layout;
    std::fill_n(superdim_pos_, kUniqueDim, -1);
    std::fill_n(subdim_pos_, kUniqueDim, -1);
    std::fill_n(subdim_size_, kUniqueDim, -1);
    layout_simplified_.clear();

    if (layout == "__undef__") return;

    // factor accumulates the digits seen since the last dimension letter
    int32_t factor = 0;
    // curr is the position assigned to the next dimension letter
    uint32_t curr = 0;
    for (size_t i = 0; i < layout.size(); ++i) {
      const LayoutDim c = layout.at(i);
      if (is_superdim(c)) {
        int pos = c - 'A';
        CHECK_EQ(factor, 0) << "Invalid layout " << layout
                            << ": invalid factor size " << factor
                            << " before dimension " << c;
        CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
                                         << ": duplicate dimension " << c;
        superdim_pos_[pos] = curr++;
        layout_simplified_.push_back(c);
      } else if (is_subdim(c)) {
        int pos = c - 'a';
        CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
                            << factor << " for dimension " << c;
        CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
                                       << ": duplicate dimension " << c;
        CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
                                        << ": duplicate dimension " << c;
        subdim_pos_[pos] = curr++;
        subdim_size_[pos] = factor;
        layout_simplified_.push_back(c);
        factor = 0;
      } else if (c >= '0' && c <= '9') {
        CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
        factor = factor * 10 + c - '0';
      } else {
        LOG(FATAL) << "Invalid layout " << layout;
      }
    }
    CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
    // every sub-dimension must come with its super-dimension
    for (LayoutDim dim : layout_simplified_) {
      CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
        << "Invalid layout " << layout << ": missing axis "
        << static_cast<char>(dim - 'a' + 'A');
    }
  }
};
} // namespace nnvm
#endif // NNVM_LAYOUT_H_
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "./base.h" #include "./base.h"
#include "./node.h" #include "./node.h"
#include "./tuple.h" #include "./tuple.h"
#include "./layout.h"
namespace nnvm { namespace nnvm {
...@@ -176,6 +177,31 @@ using FSetInputVarAttrOnCompose = std::function<void( ...@@ -176,6 +177,31 @@ using FSetInputVarAttrOnCompose = std::function<void(
NodePtr var, NodePtr var,
const int index)>; const int index)>;
/*!
* \brief Inference function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node.
* \param ilayouts Given the input layouts produced by ancestor nodes,
* it should be filled by layouts that the node requests.
* If the requested layout is different from what ancestor produces,
* a __layout_transform__ operator will be inserted automatically.
* \param last_ilayouts The input layouts requested by the node
* at the last infer pass (if any).
* This can be useful when an operator wants to keep
* the input layout the same as the original one.
* For example, after the pass of AlterOpLayout,
* transpose(input, axis=[1, 2, 3, 0]) may receive an input of NCHW16c layout,
* for which axis=[1, 2, 3, 0] is no longer meaningful.
* The last input layouts let it know which layout it originally inferred,
* i.e., the layout in the imported model.
* \param olayouts Inferred output layouts.
* \return success flag.
*/
using FInferLayout = std::function<bool(
const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts)>;
} // namespace nnvm } // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_ #endif // NNVM_OP_ATTR_TYPES_H_
...@@ -9,23 +9,12 @@ ...@@ -9,23 +9,12 @@
#include <dmlc/base.h> #include <dmlc/base.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <nnvm/tuple.h> #include <nnvm/tuple.h>
#include <nnvm/layout.h>
#include <string>
namespace nnvm { namespace nnvm {
namespace top { namespace top {
// Layout flag in spatial conv and pooling.
enum LayoutFlag {
kNCHW,
kNHWC,
kCHWN,
kNCW,
kNWC,
kCWN,
kNCDHW,
kNDHWC,
kCDHWN
};
struct DenseParam : public dmlc::Parameter<DenseParam> { struct DenseParam : public dmlc::Parameter<DenseParam> {
int units; int units;
bool use_bias; bool use_bias;
...@@ -130,7 +119,9 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> { ...@@ -130,7 +119,9 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
TShape padding; TShape padding;
TShape dilation; TShape dilation;
int groups; int groups;
int layout; std::string layout;
std::string kernel_layout;
std::string out_layout;
bool use_bias; bool use_bias;
DMLC_DECLARE_PARAMETER(Conv2DParam) { DMLC_DECLARE_PARAMETER(Conv2DParam) {
...@@ -152,14 +143,19 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> { ...@@ -152,14 +143,19 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
"At groups=2, the operation becomes equivalent to having two convolution" "At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing" "layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated."); "half the output channels, and both subsequently concatenated.");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW) .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
DMLC_DECLARE_FIELD(out_layout).set_default("__undef__")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
DMLC_DECLARE_FIELD(use_bias).set_default(true) DMLC_DECLARE_FIELD(use_bias).set_default(true)
.describe("Whether the layer uses a bias vector."); .describe("Whether the layer uses a bias vector.");
} }
...@@ -178,7 +174,8 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> { ...@@ -178,7 +174,8 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> {
TShape output_padding; TShape output_padding;
TShape dilation; TShape dilation;
int groups; int groups;
int layout; std::string layout;
std::string kernel_layout;
bool use_bias; bool use_bias;
DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) { DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) {
...@@ -202,14 +199,15 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> { ...@@ -202,14 +199,15 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> {
"At groups=2, the operation becomes equivalent to having two convolution" "At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing" "layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated."); "half the output channels, and both subsequently concatenated.");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW) .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
DMLC_DECLARE_FIELD(use_bias).set_default(true) DMLC_DECLARE_FIELD(use_bias).set_default(true)
.describe("Whether the layer uses a bias vector."); .describe("Whether the layer uses a bias vector.");
} }
...@@ -224,7 +222,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> { ...@@ -224,7 +222,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
TShape pool_size; TShape pool_size;
TShape strides; TShape strides;
TShape padding; TShape padding;
int layout; std::string layout;
bool ceil_mode; bool ceil_mode;
DMLC_DECLARE_PARAMETER(Pool2DParam) { DMLC_DECLARE_PARAMETER(Pool2DParam) {
...@@ -235,10 +233,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> { ...@@ -235,10 +233,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded" .describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points"); "on both sides for padding number of points");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
...@@ -250,13 +245,10 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> { ...@@ -250,13 +245,10 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> { struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
int layout; std::string layout;
DMLC_DECLARE_PARAMETER(GlobalPool2DParam) { DMLC_DECLARE_PARAMETER(GlobalPool2DParam) {
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
...@@ -266,15 +258,13 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> { ...@@ -266,15 +258,13 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> { struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale; int scale;
int layout; std::string layout;
DMLC_DECLARE_PARAMETER(UpSamplingParam) { DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale) DMLC_DECLARE_FIELD(scale)
.describe("upsampling scaling factor"); .describe("upsampling scaling factor");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout)
.add_enum("NCHW", kNCHW) .set_default("NCHW")
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
...@@ -282,6 +272,18 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> { ...@@ -282,6 +272,18 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
} }
}; };
/*! \brief Parameters of the layout transform operator. */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
  /*! \brief Layout string of the source data; defaults to "__undef__". */
  std::string src_layout;
  /*! \brief Layout string of the destination data; defaults to "__undef__". */
  std::string dst_layout;

  DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
    // The two descriptions name source and destination explicitly so that
    // the generated parameter documentation can tell the fields apart.
    DMLC_DECLARE_FIELD(src_layout).set_default("__undef__")
    .describe("Dimension ordering of the source data.");
    DMLC_DECLARE_FIELD(dst_layout).set_default("__undef__")
    .describe("Dimension ordering of the destination data.");
  }
};
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
......
...@@ -211,12 +211,15 @@ def _init_symbol_module(symbol_class, root_namespace): ...@@ -211,12 +211,15 @@ def _init_symbol_module(symbol_class, root_namespace):
op_names.append(py_str(plist[i])) op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.symbol" % root_namespace] module_obj = sys.modules["%s.symbol" % root_namespace]
module_obj_contrib = sys.modules["%s.contrib" % root_namespace]
module_internal = sys.modules["%s._symbol_internal" % root_namespace] module_internal = sys.modules["%s._symbol_internal" % root_namespace]
for name in op_names: for name in op_names:
hdl = OpHandle() hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = _make_atomic_symbol_function(hdl, name) function = _make_atomic_symbol_function(hdl, name)
if function.__name__.startswith('_'): if function.__name__.startswith('_contrib_'):
setattr(module_obj_contrib, function.__name__.split('_contrib_')[1], function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function) setattr(module_internal, function.__name__, function)
setattr(module_obj, function.__name__, function) setattr(module_obj, function.__name__, function)
else: else:
......
...@@ -15,7 +15,8 @@ OPT_PASS_LEVEL = { ...@@ -15,7 +15,8 @@ OPT_PASS_LEVEL = {
"SimplifyInference": 0, "SimplifyInference": 0,
"PrecomputePrune": 2, "PrecomputePrune": 2,
"OpFusion": 1, "OpFusion": 1,
"FoldScaleAxis": 3 "FoldScaleAxis": 3,
"AlterOpLayout": 3,
} }
# List of optimization pass and level when switch on # List of optimization pass and level when switch on
...@@ -139,7 +140,7 @@ def _update_shape_dtype(shape, dtype, params): ...@@ -139,7 +140,7 @@ def _update_shape_dtype(shape, dtype, params):
return shape, dtype return shape, dtype
def optimize(graph, shape, dtype="float32"): def optimize(graph, shape, dtype="float32", layout=None):
"""Perform target and parameter invariant graph optimization. """Perform target and parameter invariant graph optimization.
This is an advanced function that usually do not need to be called. This is an advanced function that usually do not need to be called.
...@@ -157,6 +158,18 @@ def optimize(graph, shape, dtype="float32"): ...@@ -157,6 +158,18 @@ def optimize(graph, shape, dtype="float32"):
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
cfg = BuildConfig.current cfg = BuildConfig.current
if cfg.pass_enabled("AlterOpLayout"):
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply(["CorrectLayout"])
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph = graph.apply(["InferShape", "InferType", "AlterOpLayout"])
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply(["CorrectLayout"])
if cfg.pass_enabled("SimplifyInference"): if cfg.pass_enabled("SimplifyInference"):
graph = graph_attr.set_shape_inputs(graph, shape) graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply(["InferShape", "SimplifyInference"]) graph = graph.apply(["InferShape", "SimplifyInference"])
...@@ -167,7 +180,8 @@ def optimize(graph, shape, dtype="float32"): ...@@ -167,7 +180,8 @@ def optimize(graph, shape, dtype="float32"):
return graph return graph
def build(graph, target=None, shape=None, dtype="float32", params=None, target_host=None): def build(graph, target=None, shape=None, dtype="float32",
params=None, target_host=None, layout=None):
"""Build graph into runtime library. """Build graph into runtime library.
The build function will optimize the graph and do the compilation. The build function will optimize the graph and do the compilation.
...@@ -204,8 +218,8 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -204,8 +218,8 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
By default, llvm is used if it is enabled, By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used. otherwise a stackvm intepreter is used.
initialize : bool, optional layout : dict of str to str or str optional
Whether to initialize variables in global dict _all_var_init. The input layout
Returns Returns
------- -------
...@@ -230,6 +244,15 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -230,6 +244,15 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
cfg = BuildConfig.current cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
# correct layout if necessary
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply("CorrectLayout")
index = graph.index
layouts = graph.json_attr("layout")
layout = {x : layouts[index.entry_id(x)] for x in index.input_names}
# Initial pass do shape type inference # Initial pass do shape type inference
ishape, _ = graph_util.infer_shape(graph, **shape) ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape)) shape.update(zip(graph.index.input_names, ishape))
...@@ -241,13 +264,14 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -241,13 +264,14 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
if _all_var_init: if _all_var_init:
init_var = initialize_variables(shape, dtype) init_var = initialize_variables(shape, dtype)
# Apply optimization # Apply optimization
graph = optimize(graph, shape, dtype) graph = optimize(graph, shape, dtype, layout)
# Precompute prune # Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"): if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params) graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
# Operator Fusion and generation # Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape) graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
graph = graph_attr.set_dtype_inputs(graph, dtype) graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", str(target), "str") graph._set_json_attr("target", str(target), "str")
if target_host is not None: if target_host is not None:
......
...@@ -96,11 +96,22 @@ def set_layout_inputs(g, layout): ...@@ -96,11 +96,22 @@ def set_layout_inputs(g, layout):
Returns Returns
------- -------
g : Graph g : Graph
The updated graph with updated dtype. The updated graph with updated layout.
""" """
list_shape = [ if isinstance(layout, dict):
layout.get(name, "default") for name in g.index.input_names] list_layout = [
g._set_json_attr("layout_inputs", list_shape, 'list_str') layout.get(name, "__undef__") for name in g.index.input_names]
elif isinstance(layout, str):
list_layout = ["__undef__"] * len(g.index.input_names)
list_layout[0] = layout
else:
raise ValueError("Input layout must be str or dict")
last_inferred_layouts = g.json_attr("layout")
if last_inferred_layouts:
input_layout = [last_inferred_layouts[g.index.entry_id(x)] for x in g.index.input_names]
for i, layout_stored in enumerate(input_layout):
list_layout[i] = list_layout[i] if list_layout[i] != '__undef__' else layout_stored
g._set_json_attr("layout_inputs", list_layout, 'list_layout')
return g return g
_move_out_module = tvm.get_global_func("nnvm.graph._move_module") _move_out_module = tvm.get_global_func("nnvm.graph._move_module")
......
"""Module space to register contrib functions. Leave empty"""
...@@ -86,6 +86,10 @@ def _conv2d(inputs, attrs): ...@@ -86,6 +86,10 @@ def _conv2d(inputs, attrs):
layout = attrs.get('layout', 'NCHW') layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']: if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d') _raise_not_supported('layout: ' + layout, 'conv2d')
if 'kernel_layout' in attrs:
kernel_layout = attrs['kernel_layout']
else:
kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW'
op_name, new_attrs = 'conv2d', {} op_name, new_attrs = 'conv2d', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter') new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel new_attrs['kernel_size'] = kernel
...@@ -94,6 +98,7 @@ def _conv2d(inputs, attrs): ...@@ -94,6 +98,7 @@ def _conv2d(inputs, attrs):
new_attrs['dilation'] = attrs.get('dilate', (1, 1)) new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout new_attrs['layout'] = layout
new_attrs['kernel_layout'] = kernel_layout
new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False' new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
...@@ -106,6 +111,10 @@ def _conv2d_transpose(inputs, attrs): ...@@ -106,6 +111,10 @@ def _conv2d_transpose(inputs, attrs):
layout = attrs.get('layout', 'NCHW') layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']: if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d_transpose') _raise_not_supported('layout: ' + layout, 'conv2d_transpose')
if 'kernel_layout' in attrs:
kernel_layout = attrs['kernel_layout']
else:
kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW'
op_name, new_attrs = 'conv2d_transpose', {} op_name, new_attrs = 'conv2d_transpose', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter') new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel new_attrs['kernel_size'] = kernel
...@@ -115,6 +124,7 @@ def _conv2d_transpose(inputs, attrs): ...@@ -115,6 +124,7 @@ def _conv2d_transpose(inputs, attrs):
new_attrs['dilation'] = attrs.get('dilate', (1, 1)) new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout new_attrs['layout'] = layout
new_attrs['kernel_layout'] = kernel_layout
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
...@@ -237,7 +247,7 @@ _convert_map = { ...@@ -237,7 +247,7 @@ _convert_map = {
'min_axis' : _rename('min'), 'min_axis' : _rename('min'),
'reshape' : _reshape, 'reshape' : _reshape,
'sum_axis' : _rename('sum'), 'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling 'UpSampling' : _upsampling,
} }
def _convert_symbol(op_name, inputs, attrs, def _convert_symbol(op_name, inputs, attrs,
......
...@@ -16,6 +16,7 @@ from . import _base ...@@ -16,6 +16,7 @@ from . import _base
from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
from .attribute import AttrScope from .attribute import AttrScope
from . import _symbol_internal as _internal from . import _symbol_internal as _internal
from . import contrib
# Use different verison of SymbolBase # Use different verison of SymbolBase
# When possible, use cython to speedup part of computation. # When possible, use cython to speedup part of computation.
......
...@@ -5,7 +5,7 @@ from __future__ import absolute_import ...@@ -5,7 +5,7 @@ from __future__ import absolute_import
import tvm import tvm
import topi import topi
from topi.util import get_const_int from topi.util import get_const_int
from .tensor import _fschedule_broadcast from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg from . import registry as reg
from .registry import OpPattern from .registry import OpPattern
...@@ -32,6 +32,11 @@ reg.register_schedule("pad", _fschedule_broadcast) ...@@ -32,6 +32,11 @@ reg.register_schedule("pad", _fschedule_broadcast)
reg.register_pattern("pad", OpPattern.INJECTIVE) reg.register_pattern("pad", OpPattern.INJECTIVE)
# layout transform
reg.register_schedule("__layout_transform__", _fschedule_injective)
reg.register_pattern("__layout_transform__", OpPattern.INJECTIVE)
@reg.register_schedule("softmax") @reg.register_schedule("softmax")
def schedule_softmax(_, outs, target): def schedule_softmax(_, outs, target):
"""Schedule definition of softmax""" """Schedule definition of softmax"""
...@@ -108,6 +113,42 @@ def schedule_conv2d(attrs, outs, target): ...@@ -108,6 +113,42 @@ def schedule_conv2d(attrs, outs, target):
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# convolution NCHWc
@reg.register_compute("_contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
    """Compute definition of conv2d NCHWc.

    Reads the convolution hyper-parameters from ``attrs``, builds the
    ``topi.nn.conv2d_NCHWc`` expression from ``inputs[0]`` (data) and
    ``inputs[1]`` (kernel) and, when ``use_bias`` is set, adds ``inputs[2]``
    after expanding it with two new axes starting at axis 1.

    Only ``dilation == (1, 1)`` and ``groups == 1`` are supported.
    """
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    kernel_h, kernel_w = attrs.get_int_tuple('kernel_size')
    groups = attrs.get_int("groups")
    channels = attrs.get_int("channels")

    assert dilation == (1, 1), "not support dilate now"
    if groups != 1:
        raise ValueError("not support arbitrary group number > 1 for now")

    data, kernel = inputs[0], inputs[1]
    result = topi.nn.conv2d_NCHWc(
        data, kernel, channels, (kernel_h, kernel_w), strides, padding)

    if not attrs.get_bool("use_bias"):
        return result

    # Expand the bias with two extra axes so it broadcasts against the output.
    expanded_bias = topi.expand_dims(inputs[2], axis=1, num_newaxis=2)
    return topi.broadcast_add(result, expanded_bias)
@reg.register_schedule("_contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
    """Schedule definition of conv2d NCHWc.

    Creates the schedule through ``topi.generic.schedule_conv2d_NCHWc`` inside
    the context of ``target``. Only ``groups == 1`` is supported.
    """
    groups = attrs.get_int("groups")
    kernel_h, kernel_w = attrs.get_int_tuple('kernel_size')
    out_channels = attrs.get_int("channels")
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")

    with tvm.target.create(target):
        # The unsupported case is rejected inside the target context,
        # exactly where the schedule would otherwise be created.
        if groups != 1:
            raise ValueError("not support group number > 1 for now")
        return topi.generic.schedule_conv2d_NCHWc(
            out_channels, (kernel_h, kernel_w), strides, padding, outs)
reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
# conv2d_transpose # conv2d_transpose
@reg.register_compute("conv2d_transpose") @reg.register_compute("conv2d_transpose")
......
...@@ -25,6 +25,7 @@ class OpPattern(object): ...@@ -25,6 +25,7 @@ class OpPattern(object):
_register_compute = tvm.get_global_func("nnvm._register_compute") _register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule") _register_schedule = tvm.get_global_func("nnvm._register_schedule")
_register_pattern = tvm.get_global_func("nnvm._register_pattern") _register_pattern = tvm.get_global_func("nnvm._register_pattern")
_register_alter_op_layout = tvm.get_global_func("nnvm.compiler._register_alter_op_layout")
def register_compute(op_name, f=None, level=10): def register_compute(op_name, f=None, level=10):
"""Register compute function for operator """Register compute function for operator
...@@ -93,3 +94,29 @@ def register_pattern(op_name, pattern, level=10): ...@@ -93,3 +94,29 @@ def register_pattern(op_name, pattern, level=10):
The priority level The priority level
""" """
_register_pattern(op_name, pattern, level) _register_pattern(op_name, pattern, level)
def register_alter_op_layout(op_name, f=None, level=10):
    """Register an alter-layout function for an operator.

    Can be called directly with ``f`` or used as a decorator when ``f`` is
    omitted.

    Parameters
    ----------
    op_name : str
        The name of the operator.

    f : function, optional
        The alter-layout function to register for ``op_name``.

    level : int
        The priority level.

    Returns
    -------
    fregister : function
        The registering decorator when ``f`` is not specified; otherwise the
        function ``f`` itself, after it has been registered.
    """
    def _do_register(func):
        """Forward ``func`` to the backend registry and hand it back."""
        _register_alter_op_layout(op_name, func, level)
        return func

    if f:
        return _do_register(f)
    return _do_register
/*!
* Copyright (c) 2018 by Contributors
* \file alter_op_layout.cc
* \brief Alter the operator layouts. Keep inferred layouts (if any) from previous stages.
* e.g., convolution may calculates faster with NCHW16c layout.
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/layout.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/pass_functions.h>
#include <tvm/tvm.h>
#include <algorithm>
#include <functional>
#include "./compile_engine.h"
#include "./graph_transform.h"
namespace nnvm {
namespace compiler {
namespace {
/*!
 * \brief Create one placeholder tensor for every output of a node.
 * \param idx_graph The indexed graph that owns the node.
 * \param nid The id of the node in idx_graph.
 * \param shape_vec Shapes indexed by entry id of idx_graph.
 * \param dtype_vec Data types indexed by entry id of idx_graph.
 * \return Placeholders carrying the shape and dtype of each output entry.
 */
tvm::Array<tvm::Tensor> GetTensorInfo(const IndexedGraph& idx_graph,
                                      const uint32_t nid,
                                      const ShapeVector& shape_vec,
                                      const DTypeVector& dtype_vec) {
  tvm::Array<tvm::Tensor> placeholders;
  for (uint32_t out_idx = 0; out_idx < idx_graph[nid].source->num_outputs(); ++out_idx) {
    const auto eid = idx_graph.entry_id(nid, out_idx);
    tvm::Array<tvm::Expr> dims;
    for (int64_t extent : shape_vec[eid]) {
      // every dimension is emitted as a 32-bit constant, so it has to fit in int
      CHECK_LE(extent, static_cast<int64_t>(std::numeric_limits<int>::max()));
      dims.push_back(tvm::make_const(tvm::Int(32), extent));
    }
    placeholders.push_back(tvm::placeholder(dims, GetTVMType(dtype_vec[eid])));
  }
  return placeholders;
}
/*!
 * \brief Replace operators by the result of their registered
 *  FTVMAlterOpLayout function.
 *
 *  The source graph must carry the "shape" and "dtype" attributes. When it
 *  also carries a "layout" attribute, the layouts of every node that is NOT
 *  replaced are recorded before the transform and written back into the
 *  "layout" attribute of the returned graph; entries that belong to replaced
 *  nodes are left as Layout::Undef().
 *
 * \param src The source graph.
 * \return The transformed graph.
 */
Graph AlterOpLayout(const Graph& src) {
  static auto& falter_op_layout =
    Op::GetAttr<nnvm::compiler::FTVMAlterOpLayout >("FTVMAlterOpLayout");

  const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
  const DTypeVector& dtype_vec = src.GetAttr<DTypeVector>("dtype");
  const IndexedGraph& idx_graph = src.indexed_graph();
  const bool has_layout = src.HasAttr("layout");

  // layouts of the original graph, indexed by original node id.
  // in_layouts_of_node[nid][i] is the layout of the i-th input of node nid.
  std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
  std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
  // node seen by the transform callback -> original node id (kept nodes only)
  std::unordered_map<const Node*, uint32_t> new_nodes;

  if (has_layout) {
    // record layouts so that LayoutTransform pass can fix layouts correctly,
    // e.g., conv2d can be replaced by some contrib implement
    // whose layout is different from the original one
    // (which was imported from a model file).
    const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
    for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
      const auto& inode = idx_graph[nid];
      if (falter_op_layout.count(inode.source->op())) {
        // do not record input layouts of nodes that will be replaced.
        continue;
      }
      std::vector<Layout> in_layout;
      for (const auto& e : inode.inputs) {
        in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
      }
      in_layouts_of_node[nid] = std::move(in_layout);

      std::vector<Layout> out_layout;
      // uint32_t instead of the non-standard `uint` typedef
      for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
        out_layout.emplace_back(layouts[idx_graph.entry_id(nid, i)]);
      }
      out_layouts_of_node[nid] = std::move(out_layout);
    }
  }

  auto transform = [&](uint32_t nid,
                       const NodePtr& n,
                       std::vector<NodeEntry>* ret) {
    nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
      falter_op_layout.get(n->op(), nullptr);
    if (fn_alter_op_layout == nullptr) {
      // node is kept: remember which original node it corresponds to.
      new_nodes[n.get()] = nid;
      return false;
    }

    // construct parameters for registered function
    std::vector<Symbol> op_inputs;
    tvm::Array<tvm::Tensor> tensor_infos;
    CHECK_EQ(n->num_inputs(), idx_graph[nid].inputs.size());
    for (uint32_t i = 0; i < n->num_inputs(); ++i) {
      const nnvm::NodeEntry& input = n->inputs[i];
      // input operator
      Symbol op_input;
      op_input.outputs.push_back(input);
      op_inputs.push_back(op_input);

      // input tinfo, extract from the original graph
      // because it was where infer_shape & infer_type applied.
      tvm::Array<tvm::Tensor> op_output_tinfos =
        GetTensorInfo(idx_graph, idx_graph[nid].inputs[i].node_id,
                      shape_vec, dtype_vec);
      tensor_infos.push_back(op_output_tinfos[input.index]);
    }
    // callback registered function to get a new operator.
    auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos);
    *ret = op.outputs;
    return true;
  };

  Graph ret = nnvm::compiler::GraphTransform(src, transform);

  if (has_layout) {
    // restore the layouts to return graph
    const auto& ret_idx = ret.indexed_graph();
    std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
    for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
      const auto& inode = ret_idx[nid];
      const auto it = new_nodes.find(inode.source);
      if (it == new_nodes.end()) continue;
      const uint32_t orig_nid = it->second;

      // The saved input layouts are ordered by input position, so they must
      // be read back by position; NodeEntry::index is the output slot of the
      // producer and is unrelated to the position in the input list.
      const std::vector<Layout>& in_layouts = in_layouts_of_node[orig_nid];
      CHECK_EQ(in_layouts.size(), inode.inputs.size());
      for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
        ret_layouts[ret_idx.entry_id(inode.inputs[i])] = in_layouts[i];
      }

      const std::vector<Layout>& out_layouts = out_layouts_of_node[orig_nid];
      CHECK_EQ(out_layouts.size(), inode.source->num_outputs());
      for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
        ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
      }
    }

    // cannot call indexed_graph() before return the origin Graph,
    // thus create a new one.
    nnvm::Graph new_ret;
    new_ret.outputs = ret.outputs;
    new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
    return new_ret;
  }

  return ret;
}
// Register AlterOpLayout under the pass name "AlterOpLayout".
// The pass is flagged as one that changes the graph.
NNVM_REGISTER_PASS(AlterOpLayout)
.set_body(AlterOpLayout)
.set_change_graph(true);
} // namespace
} // namespace compiler
} // namespace nnvm
...@@ -362,7 +362,7 @@ bool Pool2DBackward( ...@@ -362,7 +362,7 @@ bool Pool2DBackward(
std::vector<FoldChainInfo>* in_axis) { std::vector<FoldChainInfo>* in_axis) {
using top::Pool2DParam; using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed); const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if (out_info.axis == 1 && param.layout == top::kNCHW) { if (out_info.axis == 1 && param.layout == "NCHW") {
(*in_axis)[0] = out_info; (*in_axis)[0] = out_info;
} }
return false; return false;
...@@ -376,7 +376,7 @@ bool Pool2DForward( ...@@ -376,7 +376,7 @@ bool Pool2DForward(
FoldChainInfo* out_info) { FoldChainInfo* out_info) {
using top::Pool2DParam; using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed); const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if ((*in_info)[0].axis == 1 && param.layout == top::kNCHW) { if ((*in_info)[0].axis == 1 && param.layout == "NCHW") {
*out_info = (*in_info)[0]; *out_info = (*in_info)[0];
} }
return false; return false;
...@@ -467,7 +467,7 @@ bool Conv2DScaleAxisBackward( ...@@ -467,7 +467,7 @@ bool Conv2DScaleAxisBackward(
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed); const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if (out_info.kind != kPending) return false; if (out_info.kind != kPending) return false;
// only optimize for nchw for now // only optimize for nchw for now
if (param.layout == top::kNCHW && out_info.axis == 1) { if (param.layout == "NCHW" && out_info.axis == 1) {
(*in_axis)[1].kind = kMulConsumer; (*in_axis)[1].kind = kMulConsumer;
(*in_axis)[1].axis = 0; (*in_axis)[1].axis = 0;
(*in_axis)[1].source = out_info.source; (*in_axis)[1].source = out_info.source;
...@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward( ...@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward(
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed); const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if ((*in_info)[0].kind != kPending) return false; if ((*in_info)[0].kind != kPending) return false;
// only optimize for nchw for now // only optimize for nchw for now
if (param.layout == top::kNCHW && (*in_info)[0].axis == 1) { if (param.layout == "NCHW" && (*in_info)[0].axis == 1) {
(*in_info)[1].kind = kMulConsumer; (*in_info)[1].kind = kMulConsumer;
(*in_info)[1].axis = 1; (*in_info)[1].axis = 1;
(*in_info)[1].source = (*in_info)[0].source; (*in_info)[1].source = (*in_info)[0].source;
......
/*!
* Copyright (c) 2017 by Contributors
* \file layout_transform.cc
* \brief Transforms layout.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/contrib_op_param.h>
namespace nnvm {
namespace compiler {
/*!
 * \brief Get the layout name used for entries without a specific layout.
 * \return Reference to a function-local static holding "default".
 */
const TLayoutInfo& GetDefaultLayout() {
  static const TLayoutInfo kDefaultLayout("default");
  return kDefaultLayout;
}
/*!
 * \brief Create a "layout_transform" node converting src layout to dst layout.
 * \param src The source layout name.
 * \param dst The destination layout name.
 * \return The new node with parsed attributes and no inputs attached.
 */
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
                                        const std::string& dst) {
  static const nnvm::Op* transform_op = nnvm::Op::Get("layout_transform");
  // running counter appended to the generated node name
  static int counter = 0;

  nnvm::NodePtr node = nnvm::Node::Create();
  nnvm::NodeAttrs& attrs = node->attrs;
  attrs.op = transform_op;
  const int serial = counter++;
  attrs.name = src + "_to_" + dst + std::to_string(serial);
  attrs.dict["src_layout"] = src;
  attrs.dict["dst_layout"] = dst;
  // let the operator parse the attribute dictionary filled in above
  node->op()->attr_parser(&attrs);
  return node;
}
/*!
* \brief A simple layout transform pass that will
* insert layout transform nodes automatically.
*
* The source graph must carry the "shape" and "layout_inputs" attributes.
* The graph is rebuilt node by node; a layout transform node is inserted on
* every input edge whose produced layout differs from the requested one, and
* on every graph output whose layout is not the default layout.
*
* \param src The source graph.
* \return A new graph that only has its outputs set (no attribute is copied).
*/
nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& op_layout_request =
nnvm::Op::GetAttr<FTVMLayoutRequest>("FTVMLayoutRequest");
static auto& op_vecop =
nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
const std::vector<TLayoutInfo>& input_layouts =
src.GetAttr<std::vector<TLayoutInfo> >("layout_inputs");
const IndexedGraph& idx = src.indexed_graph();
// layout produced by each node entry, indexed by entry id; starts as default.
std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
// original node id -> node of the rebuilt graph.
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// use op pattern to decide whether an op is map:
// the pattern must be at most kBroadcast, and a kBroadcast op additionally
// needs every input shape to equal the shape of its first output.
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
is_map = false;
break;
}
}
}
return is_map;
};
// visit nodes in the order of the indexed graph, so the mirror of every
// input is expected to exist already when a node is processed.
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
// start from a copy of the original node.
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *(inode.source);
if (new_node->is_variable()) {
// a variable takes its layout from "layout_inputs", looked up by the
// position of the node in the graph's input node list.
auto input_iter = std::find(
idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id];
mirror_vec[nid] = new_node;
continue;
}
// if a FTVMVectorizedOp function is registered, use the node it returns
// instead of the plain copy; inputs are re-attached further below.
if (op_vecop.count(inode.source->op())) {
new_node = op_vecop[inode.source->op()](inode.source);
new_node->inputs.resize(new_node->num_inputs());
}
// set up output and input layouts
std::vector<TLayoutInfo> request_ilayouts(new_node->num_inputs(), GetDefaultLayout());
if (op_layout_request.count(new_node->op())) {
// let the operator state which input layouts it requests and which
// output layouts it produces.
std::vector<TLayoutInfo> produce_olayouts(new_node->num_outputs(), GetDefaultLayout());
CHECK(op_layout_request[new_node->op()](
new_node->attrs, &request_ilayouts, &produce_olayouts))
<< "Layout request fail";
CHECK_EQ(request_ilayouts.size(), new_node->num_inputs());
CHECK_EQ(produce_olayouts.size(), new_node->num_outputs());
for (size_t i = 0; i < new_node->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i];
}
}
// a map op whose inputs all share one layout passes that layout through
// to all of its outputs and gets no transform node on its inputs.
bool map_layout = is_map_op(nid);
if (map_layout) {
// NOTE(review): reads inputs[0], so this assumes a map op has at least
// one input -- confirm.
const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])];
for (const auto& e : inode.inputs) {
if (produce_vec[idx.entry_id(e)] != layout) {
map_layout = false;
break;
}
}
if (map_layout) {
for (size_t i = 0; i < inode.source->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = layout;
}
}
}
// rewire the inputs to the mirrored nodes, inserting a transform node
// wherever the produced layout differs from the requested one.
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] =
nnvm::NodeEntry{in, e.index, e.version};
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
TLayoutInfo request = request_ilayouts[i];
if (!map_layout && (produce != request)) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
// replace the generated name by "<producer name>_<requested layout>".
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_" + request;
tnode->inputs.emplace_back(new_node->inputs[i]);
new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
}
}
mirror_vec[nid] = new_node;
}
// graph outputs are converted back to the default layout when needed.
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : idx.outputs()) {
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
if (produce != GetDefaultLayout()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout());
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_default";
tnode->inputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
} else {
outputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
}
nnvm::Graph ret;
ret.outputs = std::move(outputs);
return ret;
}
} // namespace compiler
} // namespace nnvm
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <nnvm/op.h> #include <nnvm/op.h>
#include <nnvm/compiler/packed_func_ext.h> #include <nnvm/compiler/packed_func_ext.h>
#include <nnvm/compiler/op_attr_types.h> #include <nnvm/compiler/op_attr_types.h>
#include <tvm/runtime/c_runtime_api.h>
#include "./node_attr.h" #include "./node_attr.h"
#include "compile_engine.h" #include "compile_engine.h"
...@@ -62,6 +63,23 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys") ...@@ -62,6 +63,23 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys")
*rv = keys; *rv = keys;
}); });
// Global function that attaches a packed function to an operator as its
// FTVMAlterOpLayout attribute.
// args[0]: operator name, args[1]: the packed function, args[2]: priority level.
TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    // The copy lives on the heap and is deliberately never released,
    // to avoid free pyobject during shutdown.
    PackedFunc* callback = new PackedFunc(args[1].operator PackedFunc());
    Op& target_op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
    auto alter_layout = [callback](const NodeAttrs& attrs,
                                   const Symbol& inputs,
                                   const Array<Tensor>& tinfos) {
      TVMRetValue result = (*callback)(GetAttrDict(attrs), inputs, tinfos);
      // the callback has to hand back a Symbol
      CHECK_EQ(result.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
        << " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
        << ") but get code = " << result.type_code();
      Symbol* replacement = static_cast<Symbol*>(result.value().v_handle);
      return *replacement;
    };
    target_op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", alter_layout, args[2]);
  });
// custom version of TVM compute // custom version of TVM compute
TVM_REGISTER_GLOBAL("nnvm._register_compute") TVM_REGISTER_GLOBAL("nnvm._register_compute")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
...@@ -22,7 +22,8 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -22,7 +22,8 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
nnvm::NodeEntry beta, nnvm::NodeEntry beta,
nnvm::NodeEntry moving_mean, nnvm::NodeEntry moving_mean,
nnvm::NodeEntry moving_var, nnvm::NodeEntry moving_var,
TShape dshape) { TShape dshape,
TShape bshape) {
CHECK_NE(dshape.ndim(), 0); CHECK_NE(dshape.ndim(), 0);
CHECK(attrs.op); CHECK(attrs.op);
static const Op* bn_op = Op::Get("batch_norm"); static const Op* bn_op = Op::Get("batch_norm");
...@@ -60,13 +61,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -60,13 +61,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
"elemwise_add", bn_name + "_add_beta", {shift, beta}); "elemwise_add", bn_name + "_add_beta", {shift, beta});
} }
int axis = param.axis; int axis = param.axis;
scale = ExpandBiasToMatchAxis(scale, dshape.ndim(), 1, axis); scale = ExpandBiasToMatchAxis(scale, dshape.ndim()-bshape.ndim()+1, 1, axis);
shift = ExpandBiasToMatchAxis(shift, dshape.ndim(), 1, axis); shift = ExpandBiasToMatchAxis(shift, dshape.ndim()-bshape.ndim()+1, 1, axis);
NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data", NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
{data, scale}); {data, scale});
out = MakeNode("broadcast_add", bn_name + "_out", out = MakeNode("broadcast_add", bn_name + "_out",
{out, shift}); {out, shift});
// It is invalid to ref the other values of BN after infernece transform. // It is invalid to ref the other values of BN after inference transform.
NodeEntry undef = MakeNode("__undef__", "undef", {}); NodeEntry undef = MakeNode("__undef__", "undef", {});
return {out, undef, undef}; return {out, undef, undef};
} }
...@@ -87,7 +89,8 @@ Graph SimplifyInference(nnvm::Graph src) { ...@@ -87,7 +89,8 @@ Graph SimplifyInference(nnvm::Graph src) {
n->inputs[2], n->inputs[2],
n->inputs[3], n->inputs[3],
n->inputs[4], n->inputs[4],
shape_vec[idx.entry_id(nid, 0)]); shape_vec[idx.entry_id(nid, 0)],
shape_vec[idx.entry_id(nid, 1)]);
return true; return true;
} else if (n->op() == dropout_op) { } else if (n->op() == dropout_op) {
NodeEntry undef = MakeNode("__undef__", "undef", {}); NodeEntry undef = MakeNode("__undef__", "undef", {});
...@@ -101,7 +104,8 @@ Graph SimplifyInference(nnvm::Graph src) { ...@@ -101,7 +104,8 @@ Graph SimplifyInference(nnvm::Graph src) {
} }
NNVM_REGISTER_PASS(SimplifyInference) NNVM_REGISTER_PASS(SimplifyInference)
.set_body(SimplifyInference); .set_body(SimplifyInference)
.set_change_graph(true);
} // namespace compiler } // namespace compiler
} // namespace nnvm } // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \file correct_layout.cc
* \brief Infer and correct layout.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/layout.h>
namespace nnvm {
namespace pass {
/*!
 * \brief Create a "__layout_transform__" node converting src layout to dst layout.
 * \param src The source layout.
 * \param dst The destination layout.
 * \return The new node with parsed attributes and no inputs attached.
 */
nnvm::NodePtr CreateLayoutTransformNode(const Layout& src,
                                        const Layout& dst) {
  static const nnvm::Op* transform_op = nnvm::Op::Get("__layout_transform__");
  // running counter appended to the generated node name
  static int counter = 0;

  nnvm::NodePtr node = nnvm::Node::Create();
  nnvm::NodeAttrs& attrs = node->attrs;
  attrs.op = transform_op;
  const int serial = counter++;
  attrs.name = src.name() + "_to_" + dst.name() + std::to_string(serial);
  attrs.dict["src_layout"] = src.name();
  attrs.dict["dst_layout"] = dst.name();
  // let the operator parse the attribute dictionary filled in above
  node->op()->attr_parser(&attrs);
  return node;
}
using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;
/*!
 * \brief A simple layout infer pass that will
 *  insert layout transform nodes automatically.
 *
 * Visits the nodes of src in increasing node-id order of its indexed graph,
 * copies each node, asks the operator's FInferLayout function which input
 * layouts it requests and which output layouts it produces, and puts a layout
 * transform node in front of every input whose produced layout is defined but
 * differs from the requested one.
 *
 * Optional graph attributes read from src:
 *  - "layout_inputs": layouts of the variable nodes, indexed by the variable's
 *    position in input_nodes().
 *  - "layout": per-entry layouts (indexed by entry id of src) left by an
 *    earlier run; handed to FInferLayout as the "last pass" layouts.
 *
 * \param src The graph to process.
 * \return A new graph over the copied nodes whose "layout" attribute holds one
 *  Layout per node entry of the returned graph.
 */
nnvm::Graph CorrectLayout(nnvm::Graph src) {
// Per-operator layout inference functions.
static auto& op_infer_layout =
nnvm::Op::GetAttr<FInferLayout>("FInferLayout");
const IndexedGraph& idx = src.indexed_graph();
// mirror_vec[nid] is the copy of source node nid; null until nid is visited.
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// (new) NodePtr -> output_layouts
LayoutAttrDict new_layouts;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
// Start from a copy of the source node; for operator nodes the inputs are
// re-pointed at the mirrored producers further down.
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *(inode.source);
if (new_node->is_variable()) {
// Variable node. No operator. Only one output entry.
// Find the variable's position among the graph inputs; that position is
// the index into "layout_inputs".
auto input_iter = std::find(
idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
int64_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
if (src.HasAttr("layout_inputs")) {
new_layouts[new_node.get()] =
{src.GetAttr<std::vector<Layout> >("layout_inputs")[input_id]};
} else {
// No layout supplied: leave it undefined so a consumer can fill it in
// through the reverse-infer branch below.
new_layouts[new_node.get()] = {Layout::Undef()};
}
mirror_vec[nid] = new_node;
continue;
}
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
// set up output and input layouts
std::vector<Layout> request_ilayouts(num_inputs, Layout::Undef());
for (size_t i = 0; i < num_inputs; ++i) {
const IndexedGraph::NodeEntry& input_entry = inode.inputs[i];
const NodePtr& new_input_node = mirror_vec[input_entry.node_id];
// Every input must already have been mirrored by an earlier iteration.
CHECK(new_input_node != nullptr);
// fill inputs by previous node (DFS order) inferred layouts.
const auto& layouts_iter = new_layouts.find(new_input_node.get());
CHECK(layouts_iter != new_layouts.end());
request_ilayouts[i] = layouts_iter->second[input_entry.index];
}
// layouts produced by previous node.
// (Snapshot taken before FInferLayout gets a chance to modify request_ilayouts.)
std::vector<Layout> produce_ilayouts(request_ilayouts);
// input layouts from last pass of LayoutTransform (if apply)
std::vector<Layout> last_request_ilayouts(num_inputs, Layout::Undef());
// fill outputs by last pass of LayoutTransform (if apply)
std::vector<Layout> produce_olayouts(num_outputs, Layout::Undef());
if (src.HasAttr("layout")) {
// "layout" is indexed by entry id of the source graph.
const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
for (uint32_t i = 0; i < num_outputs; ++i) {
produce_olayouts[i] = layouts[idx.entry_id(nid, i)];
}
for (uint32_t i = 0; i < num_inputs; ++i) {
last_request_ilayouts[i] = layouts[idx.entry_id(inode.inputs[i])];
}
}
const auto& flayout = op_infer_layout[new_node->op()];
CHECK(flayout != nullptr) << "Attribute FInferLayout"
<< " is not registered by op " << inode.source->op()->name
<< " we are not able to complete layout transform.";
// request_ilayouts and produce_olayouts are passed by pointer so the
// operator can overwrite them with what it requests / produces.
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
<< "Layout infer fail";
// The inference function must not change the number of entries.
CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs);
// update new layouts
new_layouts[new_node.get()] = std::move(produce_olayouts);
// Re-wire each input to the mirrored producer and reconcile its layout with
// the one this node requested.
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version};
// insert layout_transform if necessary
const Layout& produce = produce_ilayouts[i];
const Layout& request = request_ilayouts[i];
if (produce != request && produce.defined()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
// Name the transform after the source producer plus the requested layout.
tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name();
tnode->inputs.emplace_back(new_node->inputs[i]);
nnvm::NodeEntry tnode_output{tnode, 0, 0};
new_node->inputs[i] = tnode_output;
// layout produced by LayoutTransformNode
new_layouts[tnode.get()] = {request};
} else if (!produce.defined()) {
// do reverse infer
// The producer had no layout yet: record the layout this consumer asked
// for as the layout of the producer's output entry.
new_layouts[in.get()][e.index] = request;
}
}
mirror_vec[nid] = new_node;
}
// The outputs of the result are the mirrored counterparts of src's outputs.
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : idx.outputs()) {
outputs.emplace_back(nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
// Temporary graph, used only to obtain entry ids of the rebuilt graph.
nnvm::Graph ret;
ret.outputs = outputs;
// restore the layouts to return graph
const auto& ret_idx = ret.indexed_graph();
std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
const auto& inode = ret_idx[nid];
const auto& layout_iter = new_layouts.find(inode.source);
// Nodes with no recorded layouts keep Layout::Undef() for all their entries.
if (layout_iter != new_layouts.end()) {
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
ret_layouts[ret_idx.entry_id(nid, i)] = std::move(layout_iter->second[i]);
}
}
}
// cannot call indexed_graph() before return the origin Graph,
// thus create a new one
nnvm::Graph new_ret;
new_ret.outputs = std::move(outputs);
new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
return new_ret;
}
// register pass
// CorrectLayout is registered as a graph-changing pass that provides the
// graph attribute "layout".
NNVM_REGISTER_PASS(CorrectLayout)
.describe("Return a layout-transformed graph of src.")
.set_body(CorrectLayout)
.provide_graph_attr("layout")
.set_change_graph(true);
// NOTE(review): presumably makes LayoutVector values stored in graph attributes
// JSON-serializable under the type key "list_layout" -- LayoutVector is declared
// outside this section, confirm against its definition.
DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout);
} // namespace pass
} // namespace nnvm
...@@ -158,7 +158,7 @@ Graph InferAttr(Graph &&ret, ...@@ -158,7 +158,7 @@ Graph InferAttr(Graph &&ret,
} else { } else {
CHECK(!last_iter) CHECK(!last_iter)
<< "Attribute " << infer_name << "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name << " is not registered by op " << inode.source->op()->name
<< " we are not able to complete the inference because of this"; << " we are not able to complete the inference because of this";
} }
} }
......
...@@ -6,9 +6,12 @@ ...@@ -6,9 +6,12 @@
#ifndef NNVM_TOP_ELEMWISE_OP_COMMON_H_ #ifndef NNVM_TOP_ELEMWISE_OP_COMMON_H_
#define NNVM_TOP_ELEMWISE_OP_COMMON_H_ #define NNVM_TOP_ELEMWISE_OP_COMMON_H_
#include <nnvm/layout.h>
#include <nnvm/top/nn.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <functional>
#include "./op_common.h" #include "./op_common.h"
namespace nnvm { namespace nnvm {
...@@ -100,12 +103,176 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -100,12 +103,176 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
attrs, in_attrs, out_attrs, -1); attrs, in_attrs, out_attrs, -1);
} }
template<int n_in, int n_out>
inline bool ElemwiseFixedLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts,
const std::function<Layout(const Layout& in)>& finfer) {
const size_t in_size = (n_in == -1) ? in_layouts->size() : static_cast<size_t>(n_in);
const size_t out_size = (n_out == -1) ? out_layouts->size() : static_cast<size_t>(n_out);
auto deduce = [&](Layout *target, const std::vector<Layout> *vec,
size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
if (vec->at(i).defined()) {
if (!target->defined()) {
*target = vec->at(i);
}
CHECK_EQ(*target, vec->at(i))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << *target
<< ", got " << vec->at(i);
}
}
};
Layout in, last_in, out;
deduce(&in, in_layouts, in_size, "input");
deduce(&last_in, last_in_layouts, in_size, "input (last infer pass)");
deduce(&out, out_layouts, out_size, "output");
if (!last_in.defined()) {
last_in = in;
} else {
// else we copy in_layout produced by last infer pass to in_layout,
// and let LayoutTransform pass
// to insert an layout_transform node to fix the input layout.
in = last_in;
}
out = finfer(in);
auto write = [](std::vector<Layout> *vec, Layout& value, size_t size) {
for (size_t i = 0; i < size; ++i) {
vec->at(i) = value;
}
};
if (in.defined()) write(in_layouts, in, in_size);
if (out.defined()) write(out_layouts, out, out_size);
return true;
}
/*!
 * \brief Fix the input layout as the previous inferred (if any) and copy to output.
 * \tparam n_in Number of inputs, -1 to take it from in_layouts.
 * \tparam n_out Number of outputs, -1 to take it from out_layouts.
 * \return The result of ElemwiseFixedLayout.
 */
template<int n_in, int n_out>
inline bool ElemwiseFixedLayoutCopyToOut(const NodeAttrs& attrs,
                                         std::vector<Layout> *in_layouts,
                                         const std::vector<Layout> *last_in_layouts,
                                         std::vector<Layout> *out_layouts) {
  // The outputs inherit whichever layout the inputs end up with.
  const auto same_as_input = [](const Layout& in) {
    return in;
  };
  return ElemwiseFixedLayout<n_in, n_out>(
    attrs, in_layouts, last_in_layouts, out_layouts, same_as_input);
}
/*!
 * \brief Fix the input layout as the previous inferred (if any) and do not define output.
 * \tparam n_in Number of inputs, -1 to take it from in_layouts.
 * \tparam n_out Number of outputs, -1 to take it from out_layouts.
 * \return The result of ElemwiseFixedLayout.
 */
template<int n_in, int n_out>
inline bool ElemwiseFixedLayoutUnknownOut(const NodeAttrs& attrs,
                                          std::vector<Layout> *in_layouts,
                                          const std::vector<Layout> *last_in_layouts,
                                          std::vector<Layout> *out_layouts) {
  // Whatever the input layout is, the output layout is reported as undefined,
  // so out_layouts is left as it was.
  const auto always_undefined = [](const Layout&) {
    return Layout::Undef();
  };
  return ElemwiseFixedLayout<n_in, n_out>(
    attrs, in_layouts, last_in_layouts, out_layouts, always_undefined);
}
/*!
 * \brief take arbitrary input layout and copy to output
 *
 * All considered inputs must carry the same layout. If that layout is defined
 * it is written to every considered output; otherwise the outputs are left
 * untouched. The layouts of the last infer pass are not consulted.
 *
 * \tparam n_in Number of inputs to consider, or -1 to use in_layouts->size().
 * \tparam n_out Number of outputs to consider, or -1 to use out_layouts->size().
 * \return Always true.
 */
template<int n_in, int n_out>
inline bool ElemwiseArbitraryLayout(const NodeAttrs& attrs,
                                    std::vector<Layout> *in_layouts,
                                    const std::vector<Layout> *last_in_layouts,
                                    std::vector<Layout> *out_layouts) {
  const size_t in_size =
    (n_in == -1) ? in_layouts->size() : static_cast<size_t>(n_in);

  Layout in;
  for (size_t i = 0; i < in_size; ++i) {
    // Adopt the current entry while nothing defined has been seen yet.
    if (!in.defined()) in = in_layouts->at(i);
    CHECK_EQ(in, in_layouts->at(i))
      << "Incompatible attr in node " << attrs.name << " at " << i
      << "-th input: expected " << in
      << ", got " << in_layouts->at(i);
  }

  // Nothing to propagate when no input layout is known.
  if (!in.defined()) return true;

  const size_t out_size =
    (n_out == -1) ? out_layouts->size() : static_cast<size_t>(n_out);
  for (size_t k = 0; k < out_size; ++k) {
    out_layouts->at(k) = in;
  }
  return true;
}
/*!
 * \brief Layout inference for binary element-wise operators that keeps the
 *  layout of the left operand where possible.
 *
 *  - both inputs undefined: nothing is changed;
 *  - exactly one input undefined: it takes the layout of the other one;
 *  - equal layouts: passed straight to the output;
 *  - different layouts: the right input is asked to use the left layout when
 *    rhs.convertible(lhs) holds, otherwise both inputs fall back to the
 *    layouts of the last infer pass (which then must be defined and equal).
 *
 * \param attrs The node attributes (only the name is used, in error messages).
 * \param in_layouts The two input layouts, updated in place.
 * \param last_in_layouts The two input layouts of the last infer pass.
 * \param out_layouts The single output layout, updated in place.
 * \return Always true.
 */
inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
                                         std::vector<Layout> *in_layouts,
                                         const std::vector<Layout> *last_in_layouts,
                                         std::vector<Layout> *out_layouts) {
  CHECK_EQ(in_layouts->size(), 2U);
  CHECK_EQ(last_in_layouts->size(), 2U);
  CHECK_EQ(out_layouts->size(), 1U);

  const Layout& lhs_last = (*last_in_layouts)[0];
  const Layout& rhs_last = (*last_in_layouts)[1];
  // The last pass either knew both input layouts or neither of them.
  CHECK((lhs_last.defined() && rhs_last.defined()) ||
        (!lhs_last.defined() && !rhs_last.defined()));

  const Layout& lhs = (*in_layouts)[0];
  const Layout& rhs = (*in_layouts)[1];

  if (!lhs.defined() && !rhs.defined()) {
    CHECK(!lhs_last.defined() && !rhs_last.defined())
      << "Lost input layouts in node " << attrs.name
      << ": last inferred lhs=" << lhs_last << ", rhs=" << rhs_last;
    return true;
  }

  if (!lhs.defined()) {
    // Only the right side is known: the left side and the output follow it.
    CHECK(!lhs_last.defined() && !rhs_last.defined());
    in_layouts->at(0) = rhs;
    out_layouts->at(0) = rhs;
    return true;
  }

  if (!rhs.defined()) {
    // Only the left side is known: the right side and the output follow it.
    CHECK(!lhs_last.defined() && !rhs_last.defined());
    in_layouts->at(1) = lhs;
    out_layouts->at(0) = lhs;
    return true;
  }

  if (lhs == rhs) {
    // for same layout, we can always do binary calculation
    // and pass the layout to next layer
    out_layouts->at(0) = lhs;
    return true;
  }

  if (!rhs.convertible(lhs)) {
    // No conversion available: go back to what the last pass agreed on.
    CHECK(lhs_last.defined() && rhs_last.defined())
      << "Incompatible input layouts in node " << attrs.name
      << ". lhs: " << lhs << ", rhs: " << rhs;
    CHECK(lhs_last == rhs_last);
    in_layouts->at(0) = lhs_last;
    in_layouts->at(1) = rhs_last;
    out_layouts->at(0) = lhs_last;
    return true;
  }

  // Request the left layout for the right input as well.
  in_layouts->at(1) = lhs;
  out_layouts->at(0) = lhs;
  return true;
}
#define NNVM_REGISTER_ELEMWISE_UNARY_OP(name) \ #define NNVM_REGISTER_ELEMWISE_UNARY_OP(name) \
NNVM_REGISTER_OP(name) \ NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \ .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseArbitraryLayout<1, 1>) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \ [](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \ return std::vector<std::pair<int, int> >{{0, 0}}; \
...@@ -131,6 +298,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -131,6 +298,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \ .set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseBinaryKeepLeftLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \ [](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
...@@ -150,6 +319,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -150,6 +319,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
ParamGetAttrDict<ElementWiseReduceParam>) \ ParamGetAttrDict<ElementWiseReduceParam>) \
.set_attr<nnvm::FInferShape>("FInferShape", \ .set_attr<nnvm::FInferShape>("FInferShape", \
ElementWiseReduceShape) \ ElementWiseReduceShape) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseFixedLayoutCopyToOut<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \ .set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \
.add_argument("args", "Symbol[]", "Positional input arguments") .add_argument("args", "Symbol[]", "Positional input arguments")
...@@ -166,6 +337,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -166,6 +337,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
static_cast<int>(kFloat32)); \ static_cast<int>(kFloat32)); \
return true; \ return true; \
}) \ }) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_attr<FGradient>( \ .set_attr<FGradient>( \
"FGradient", [](const NodePtr& n, \ "FGradient", [](const NodePtr& n, \
const std::vector<NodeEntry>& ograds) { \ const std::vector<NodeEntry>& ograds) { \
......
...@@ -5,11 +5,22 @@ ...@@ -5,11 +5,22 @@
*/ */
#include <nnvm/op.h> #include <nnvm/op.h>
#include <nnvm/node.h> #include <nnvm/node.h>
#include <nnvm/layout.h>
#include <nnvm/op_attr_types.h> #include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h> #include <nnvm/top/nn.h>
#include <tvm/tensor.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/compiler/op_attr_types.h>
#include <tvm/tvm.h>
#include "./nn_common.h" #include "./nn_common.h"
#include "../op_common.h" #include "../op_common.h"
#include "../elemwise_op_common.h" #include "../elemwise_op_common.h"
#include "topi/nn.h"
using tvm::Tensor;
using tvm::Array;
using nnvm::compiler::FTVMCompute;
namespace nnvm { namespace nnvm {
namespace top { namespace top {
...@@ -20,7 +31,26 @@ DMLC_REGISTER_PARAMETER(Conv2DParam); ...@@ -20,7 +31,26 @@ DMLC_REGISTER_PARAMETER(Conv2DParam);
inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape, std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) { std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed); const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
const Layout in_layout(param.layout);
const Layout kernel_layout(param.kernel_layout);
CHECK(in_layout.convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.convertible(kOIHW))
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param.out_layout);
if (!out_layout.defined()) out_layout = in_layout;
CHECK(out_layout.convertible(kNCHW))
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
if (param.use_bias) { if (param.use_bias) {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]"; CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
} else { } else {
...@@ -30,7 +60,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -30,7 +60,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
TShape dshape = in_shape->at(0); TShape dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false; if (dshape.ndim() == 0) return false;
dshape = ConvertLayout(dshape, param.layout, kNCHW); dshape = ConvertLayout(dshape, in_layout, kNCHW);
CHECK_EQ(dshape.ndim(), 4U) << "Input data should be 4D"; CHECK_EQ(dshape.ndim(), 4U) << "Input data should be 4D";
CHECK_EQ(param.kernel_size.ndim(), 2U); CHECK_EQ(param.kernel_size.ndim(), 2U);
...@@ -48,13 +78,20 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -48,13 +78,20 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
param.kernel_size[0], param.kernel_size[0],
param.kernel_size[1]}); param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout, true); wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
wshape[0] *= param.groups; wshape[0] *= param.groups;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
if (param.use_bias) { if (param.use_bias) {
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, static const Layout default_bias_layout("C");
Conv2DParam::kBias, TShape({param.channels})); TShape bias_shape({param.channels});
auto oc_block = out_layout.subsizeof('C');
if (oc_block > 0) {
size_t split_axis = (out_layout.indexof('C') < out_layout.indexof('c')) ? 1 : 0;
bias_shape = ConvertLayout(bias_shape, default_bias_layout,
default_bias_layout.split('C', split_axis, oc_block));
}
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kBias, bias_shape);
} }
// dilation // dilation
dim_t dilated_ksize_y = 1 + (param.kernel_size[0] - 1) * param.dilation[0]; dim_t dilated_ksize_y = 1 + (param.kernel_size[0] - 1) * param.dilation[0];
...@@ -66,12 +103,11 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -66,12 +103,11 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
if (dshape[3] != 0) { if (dshape[3] != 0) {
oshape[3] = (dshape[3] + param.padding[1] * 2 - dilated_ksize_x) / param.strides[1] + 1; oshape[3] = (dshape[3] + param.padding[1] * 2 - dilated_ksize_x) / param.strides[1] + 1;
} }
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, ConvertLayout(oshape, kNCHW, out_layout));
ConvertLayout(oshape, kNCHW, param.layout));
// Perform incomplete shape inference. Fill in the missing values in data shape. // Perform incomplete shape inference. Fill in the missing values in data shape.
// 1) We can always fill in the batch_size. // 1) We can always fill in the batch_size.
// 2) We can back-calculate the input height/width if the corresponding stride is 1. // 2) We can back-calculate the input height/width if the corresponding stride is 1.
oshape = ConvertLayout((*out_shape)[0], param.layout, kNCHW); oshape = ConvertLayout((*out_shape)[0], out_layout, kNCHW);
dshape[0] = oshape[0]; dshape[0] = oshape[0];
if (oshape[2] && param.strides[0] == 1) { if (oshape[2] && param.strides[0] == 1) {
dshape[2] = oshape[2] + dilated_ksize_y - 1 - 2 * param.padding[0]; dshape[2] = oshape[2] + dilated_ksize_y - 1 - 2 * param.padding[0];
...@@ -80,7 +116,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -80,7 +116,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
dshape[3] = oshape[3] + dilated_ksize_x - 1 - 2 * param.padding[1]; dshape[3] = oshape[3] + dilated_ksize_x - 1 - 2 * param.padding[1];
} }
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kData, NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kData,
ConvertLayout(dshape, kNCHW, param.layout)); ConvertLayout(dshape, kNCHW, in_layout));
// Check whether the kernel sizes are valid // Check whether the kernel sizes are valid
if (dshape[2] != 0) { if (dshape[2] != 0) {
CHECK_LE(dilated_ksize_y, dshape[2] + 2 * param.padding[0]) CHECK_LE(dilated_ksize_y, dshape[2] + 2 * param.padding[0])
...@@ -93,6 +129,41 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -93,6 +129,41 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
/*!
 * \brief Layout inference of conv2d, driven entirely by Conv2DParam.
 *
 * Data takes param.layout, weight takes param.kernel_layout and the output
 * takes param.out_layout (param.layout when out_layout is undefined). With
 * use_bias the bias layout is "C", split by the output's channel block size
 * when the output layout has a sub-dimension of 'C'.
 *
 * \param attrs The node attributes holding the parsed Conv2DParam.
 * \param ilayouts Input layouts: [data, weight] or [data, weight, bias].
 * \param last_ilayouts Input layouts of the last infer pass (unused).
 * \param olayouts Output layouts; exactly one entry.
 * \return Always true.
 */
inline bool Conv2DInferLayout(const NodeAttrs& attrs,
                              std::vector<Layout> *ilayouts,
                              const std::vector<Layout> *last_ilayouts,
                              std::vector<Layout> *olayouts) {
  const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);

  const Layout in_layout(param.layout);
  Layout out_layout(param.out_layout);
  if (!out_layout.defined()) out_layout = in_layout;
  const Layout kernel_layout(param.kernel_layout);

  if (param.use_bias) {
    CHECK_EQ(ilayouts->size(), 3U) << "Input:[data, weight, bias]";
  } else {
    CHECK_EQ(ilayouts->size(), 2U) << "Input:[data, weight]";
  }

  // data and weight are handled the same way with or without bias
  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, in_layout);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kernel_layout);

  if (param.use_bias) {
    // automatically decide bias layout
    Layout bias_layout("C");
    const auto oc_block = out_layout.subsizeof('C');
    if (oc_block > 0) {
      const size_t split_axis =
        (out_layout.indexof('C') < out_layout.indexof('c')) ? 1 : 0;
      bias_layout = bias_layout.split('C', split_axis, oc_block);
    }
    NNVM_ASSIGN_LAYOUT(*ilayouts, 2, bias_layout);
  }

  CHECK_EQ(olayouts->size(), 1U);
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, out_layout);
  return true;
}
NNVM_REGISTER_OP(conv2d) NNVM_REGISTER_OP(conv2d)
.describe(R"code(2D convolution layer (e.g. spatial convolution over images). .describe(R"code(2D convolution layer (e.g. spatial convolution over images).
...@@ -118,6 +189,7 @@ a bias vector is created and added to the outputs. ...@@ -118,6 +189,7 @@ a bias vector is created and added to the outputs.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape) .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DInferLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>) .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2) .set_support_level(2)
...@@ -130,6 +202,23 @@ a bias vector is created and added to the outputs. ...@@ -130,6 +202,23 @@ a bias vector is created and added to the outputs.
n->attrs.dict); n->attrs.dict);
}); });
// Registration of _contrib_conv2d_NCHWc, a conv2d variant whose declared
// arguments are a packed 5D data tensor and a packed 6D weight tensor.
// It uses Conv2DParam and the Conv2DInferShape / Conv2DInferLayout functions
// for attribute parsing and inference; no FGradient or FTVMCompute attribute
// is set in this chain.
NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
.describe(R"code(2D convolution layer (e.g. spatial convolution over images).
)code" NNVM_ADD_FILELINE)
.add_argument("data", "5D Tensor", "Packed input data.")
.add_argument("weight", "6D Tensor", "Packed weight matrix.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(Conv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<Conv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DInferLayout)
.set_num_outputs(1)
// the number of inputs is decided by UseBiasNumInputs from the parsed params
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2);
NNVM_REGISTER_OP(_conv2d_grad) NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad. .describe(R"code(2D convolution grad.
...@@ -163,16 +252,21 @@ DMLC_REGISTER_PARAMETER(Conv2DTransposeParam); ...@@ -163,16 +252,21 @@ DMLC_REGISTER_PARAMETER(Conv2DTransposeParam);
inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape, std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) { std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const Conv2DTransposeParam& param = nnvm::get<Conv2DTransposeParam>(attrs.parsed); const Conv2DTransposeParam& param = nnvm::get<Conv2DTransposeParam>(attrs.parsed);
const Layout layout(param.layout);
const Layout kernel_layout(param.kernel_layout);
if (param.use_bias) { if (param.use_bias) {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]"; CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
} else { } else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
} }
CHECK_EQ(out_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U);
const TShape& dshape = (*in_shape)[Conv2DTransposeParam::kData]; const TShape& dshape = (*in_shape)[Conv2DTransposeParam::kData];
if (dshape.ndim() == 0) return false; if (dshape.ndim() == 0) return false;
TShape dshape_nchw = ConvertLayout(dshape, param.layout, kNCHW); TShape dshape_nchw = ConvertLayout(dshape, layout, kNCHW);
CHECK_EQ(dshape_nchw[1] % param.groups, 0U) CHECK_EQ(dshape_nchw[1] % param.groups, 0U)
<< "input num_filter must divide group size"; << "input num_filter must divide group size";
...@@ -189,7 +283,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, ...@@ -189,7 +283,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
param.channels / param.groups, param.channels / param.groups,
param.kernel_size[0], param.kernel_size[0],
param.kernel_size[1]}); param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout, true); wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
if (param.use_bias) { if (param.use_bias) {
...@@ -208,7 +302,33 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, ...@@ -208,7 +302,33 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
oshape[3] = (param.strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - oshape[3] = (param.strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param.padding[1] + param.output_padding[1]); 2 * param.padding[1] + param.output_padding[1]);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0,
ConvertLayout(oshape, kNCHW, param.layout)); ConvertLayout(oshape, kNCHW, layout));
return true;
}
/*!
 * \brief Layout inference of conv2d_transpose, driven by Conv2DTransposeParam.
 *
 * Data and output take param.layout, weight takes param.kernel_layout and,
 * with use_bias, the bias takes layout "C".
 *
 * \param attrs The node attributes holding the parsed Conv2DTransposeParam.
 * \param ilayouts Input layouts: [data, weight] or [data, weight, bias].
 * \param last_ilayouts Input layouts of the last infer pass (unused).
 * \param olayouts Output layouts; exactly one entry.
 * \return Always true.
 */
inline bool Conv2DTransposeInferLayout(const NodeAttrs& attrs,
                                       std::vector<Layout> *ilayouts,
                                       const std::vector<Layout> *last_ilayouts,
                                       std::vector<Layout> *olayouts) {
  const Conv2DTransposeParam& param = nnvm::get<Conv2DTransposeParam>(attrs.parsed);

  const Layout in_layout(param.layout);
  const Layout kernel_layout(param.kernel_layout);

  if (param.use_bias) {
    CHECK_EQ(ilayouts->size(), 3U) << "Input:[data, weight, bias]";
  } else {
    CHECK_EQ(ilayouts->size(), 2U) << "Input:[data, weight]";
  }

  // data and weight are handled the same way with or without bias
  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, in_layout);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kernel_layout);
  if (param.use_bias) {
    NNVM_ASSIGN_LAYOUT(*ilayouts, 2, Layout("C"));
  }

  CHECK_EQ(olayouts->size(), 1U);
  // the output keeps the layout of the data input
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, in_layout);
  return true;
}
...@@ -243,6 +363,7 @@ said convolution. ...@@ -243,6 +363,7 @@ said convolution.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DTransposeParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DTransposeParam>)
.set_attr<FInferShape>("FInferShape", Conv2DTransposeInferShape) .set_attr<FInferShape>("FInferShape", Conv2DTransposeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DTransposeInferLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DTransposeParam>) .set_num_inputs(UseBiasNumInputs<Conv2DTransposeParam>)
.set_support_level(2); .set_support_level(2);
......
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
* \file nn.cc * \file nn.cc
* \brief Property def of nn operators. * \brief Property def of nn operators.
*/ */
#include <tvm/tvm.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <nnvm/op.h> #include <nnvm/op.h>
#include <nnvm/node.h> #include <nnvm/node.h>
#include <nnvm/layout.h>
#include <nnvm/op_attr_types.h> #include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h> #include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h> #include <nnvm/top/nn.h>
...@@ -20,6 +22,8 @@ ...@@ -20,6 +22,8 @@
namespace nnvm { namespace nnvm {
namespace top { namespace top {
using tvm::Var;
using tvm::Expr;
using tvm::Tensor; using tvm::Tensor;
using tvm::Array; using tvm::Array;
using nnvm::compiler::FTVMCompute; using nnvm::compiler::FTVMCompute;
...@@ -82,6 +86,8 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored. ...@@ -82,6 +86,8 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
.set_attr<FInferShape>("FInferShape", DenseInferShape) .set_attr<FInferShape>("FInferShape", DenseInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
// leave weight & bias layout undefined
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
...@@ -161,6 +167,7 @@ NNVM_REGISTER_OP(dropout) ...@@ -161,6 +167,7 @@ NNVM_REGISTER_OP(dropout)
.set_num_outputs(2) .set_num_outputs(2)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 2>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 2>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) { .set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
return 1; return 1;
}) })
...@@ -184,13 +191,75 @@ inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs, ...@@ -184,13 +191,75 @@ inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs,
CHECK((size_t)param.axis < dshape.Size()); CHECK((size_t)param.axis < dshape.Size());
TShape bshape({dshape[param.axis]}); TShape bshape({dshape[param.axis]});
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, bshape); if (in_shape->at(1).ndim() == 0) in_shape->at(1) = bshape;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 2, bshape); if (in_shape->at(2).ndim() == 0) in_shape->at(2) = bshape;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 3, bshape); if (in_shape->at(3).ndim() == 0) in_shape->at(3) = bshape;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 4, bshape); if (in_shape->at(4).ndim() == 0) in_shape->at(4) = bshape;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 1, bshape); out_shape->at(1) = in_shape->at(3);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 2, bshape); out_shape->at(2) = in_shape->at(4);
return true;
}
/*!
 * \brief Layout inference of batch_norm.
 *
 * The offered data layout is accepted only if 'C' sits at index param.axis,
 * a sub-channel 'c' (if present) is the last dimension, and no sub-dimension
 * other than 'c' occurs. Otherwise the data layout of the last infer pass is
 * requested instead (and must be defined). The four parameter inputs and the
 * two extra outputs use layout "C", split by the channel block size when the
 * data layout has a sub-dimension of 'C'.
 *
 * \param attrs The node attributes holding the parsed BatchNormParam.
 * \param in_layouts Five input layouts, data first; updated in place.
 * \param last_in_layouts Five input layouts of the last infer pass.
 * \param out_layouts Three output layouts; updated in place.
 * \return Always true.
 */
inline bool BatchNormInferLayout(const NodeAttrs& attrs,
                                 std::vector<Layout> *in_layouts,
                                 const std::vector<Layout> *last_in_layouts,
                                 std::vector<Layout> *out_layouts) {
  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
  CHECK_EQ(in_layouts->size(), 5U);
  CHECK_EQ(last_in_layouts->size(), 5U);
  CHECK_EQ(out_layouts->size(), 3U);

  Layout data_layout = in_layouts->at(0);
  const Layout& origin_data_layout = last_in_layouts->at(0);
  Layout param_layout("C");

  // convert it to the original one.
  const auto use_origin_layout = [&]() {
    data_layout = origin_data_layout;
    NNVM_ASSIGN_LAYOUT(*in_layouts, 0, origin_data_layout);
  };

  if (data_layout.defined()) {
    if (data_layout.indexof('C') != param.axis) {
      CHECK(origin_data_layout.defined())
        << "Channel in data layout " << data_layout
        << " is not at index " << param.axis;
      use_origin_layout();
    } else if (data_layout.indexof('c') >= 0 &&
               static_cast<uint32_t>(data_layout.indexof('c')) != (data_layout.ndim()-1)) {
      CHECK(origin_data_layout.defined())
        << "sub-channel c in data layout " << data_layout
        << " does not at the final dimension";
      use_origin_layout();
    } else {
      for (Layout::LayoutDim axis : data_layout) {
        if (!Layout::is_subdim(axis) || axis == 'c') continue;
        CHECK(origin_data_layout.defined())
          << "sub-axis other than c appears in data layout " << data_layout;
        use_origin_layout();
        // data_layout was just reassigned; stop iterating over it.
        break;
      }
    }

    // decide the param layout
    if (data_layout.defined()) {
      const auto channel_block = data_layout.subsizeof('C');
      if (channel_block > 0) {
        param_layout = param_layout.split('C', 1, channel_block);
      }
    }
  }

  // inputs: data followed by the four per-channel parameter tensors
  NNVM_ASSIGN_LAYOUT(*in_layouts, 0, data_layout);
  NNVM_ASSIGN_LAYOUT(*in_layouts, 1, param_layout);
  NNVM_ASSIGN_LAYOUT(*in_layouts, 2, param_layout);
  NNVM_ASSIGN_LAYOUT(*in_layouts, 3, param_layout);
  NNVM_ASSIGN_LAYOUT(*in_layouts, 4, param_layout);

  // outputs: normalized data followed by two per-channel tensors
  NNVM_ASSIGN_LAYOUT(*out_layouts, 0, data_layout);
  NNVM_ASSIGN_LAYOUT(*out_layouts, 1, param_layout);
  NNVM_ASSIGN_LAYOUT(*out_layouts, 2, param_layout);
  return true;
}
...@@ -238,6 +307,7 @@ axis to be the last item in the input shape. ...@@ -238,6 +307,7 @@ axis to be the last item in the input shape.
.add_arguments(BatchNormParam::__FIELDS__()) .add_arguments(BatchNormParam::__FIELDS__())
.set_attr_parser(ParamParser<BatchNormParam>) .set_attr_parser(ParamParser<BatchNormParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BatchNormParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BatchNormParam>)
.set_attr<FInferLayout>("FInferLayout", BatchNormInferLayout)
.set_num_inputs(5) .set_num_inputs(5)
.set_num_outputs(3) .set_num_outputs(3)
.set_attr<FInferShape>("FInferShape", BatchNormInferShape) .set_attr<FInferShape>("FInferShape", BatchNormInferShape)
...@@ -275,6 +345,7 @@ NNVM_REGISTER_OP(softmax) ...@@ -275,6 +345,7 @@ NNVM_REGISTER_OP(softmax)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
...@@ -331,6 +402,7 @@ NNVM_REGISTER_OP(log_softmax) ...@@ -331,6 +402,7 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -388,6 +460,7 @@ NNVM_REGISTER_OP(leaky_relu) ...@@ -388,6 +460,7 @@ NNVM_REGISTER_OP(leaky_relu)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -439,6 +512,30 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs, ...@@ -439,6 +512,30 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
return true; return true;
} }
inline bool PReluInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
CHECK_EQ(in_layouts->size(), 2U);
CHECK_EQ(last_in_layouts->size(), 2U);
CHECK_EQ(out_layouts->size(), 1U);
const Layout& data_layout = last_in_layouts->at(0).defined() ?
last_in_layouts->at(0) : in_layouts->at(0);
if (data_layout.defined()) {
CHECK(data_layout.indexof('C') == param.axis && !data_layout.contains('c'))
<< "Channel in data layout " << data_layout
<< " is not at index " << param.axis;
}
NNVM_ASSIGN_LAYOUT(*in_layouts, 0, data_layout);
NNVM_ASSIGN_LAYOUT(*in_layouts, 1, Layout("C"));
NNVM_ASSIGN_LAYOUT(*out_layouts, 0, data_layout);
return true;
}
NNVM_REGISTER_OP(prelu) NNVM_REGISTER_OP(prelu)
.describe(R"code(Parametric version of a Rectified Linear Unit. .describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha`` It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
...@@ -453,6 +550,7 @@ where :math:`*` is an channelwise multiplication for each sample in the ...@@ -453,6 +550,7 @@ where :math:`*` is an channelwise multiplication for each sample in the
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", PReluInferShape) .set_attr<FInferShape>("FInferShape", PReluInferShape)
.set_attr<FInferLayout>("FInferLayout", PReluInferLayout)
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { .set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "alpha"}; return std::vector<std::string>{"data", "alpha"};
}) })
...@@ -499,6 +597,7 @@ NNVM_REGISTER_OP(pad) ...@@ -499,6 +597,7 @@ NNVM_REGISTER_OP(pad)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", PadInferShape) .set_attr<FInferShape>("FInferShape", PadInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -520,5 +619,94 @@ NNVM_REGISTER_OP(pad) ...@@ -520,5 +619,94 @@ NNVM_REGISTER_OP(pad)
}) })
.set_support_level(1); .set_support_level(1);
// layout transformer
// Register LayoutTransformParam (fields src_layout and dst_layout) with the
// dmlc parameter system; it is the parameter type parsed by the
// `__layout_transform__` operator.
DMLC_REGISTER_PARAMETER(LayoutTransformParam);
inline bool LayoutTransformInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
CHECK_EQ(out_attrs->size(), 1U);
const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
const TShape &dshape = (*in_attrs)[0];
if (dshape.ndim() == 0) return false;
const TShape &oshape = ConvertLayout(dshape,
Layout(param.src_layout),
Layout(param.dst_layout));
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
return true;
}
// Registers the `__layout_transform__` operator: one input, one output,
// parameterized by LayoutTransformParam (src_layout -> dst_layout).
NNVM_REGISTER_OP(__layout_transform__)
.describe(R"code(Transform the input data layout.
For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes
the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.add_argument("data", "Tensor", "Input data.")
.add_arguments(LayoutTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<LayoutTransformParam>)
.set_attr<FInferShape>("FInferShape", LayoutTransformInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
// Layout inference: the input is requested in src_layout and the output is
// produced in dst_layout. NNVM_ASSIGN_LAYOUT only writes a slot when the
// given layout is defined.
.set_attr<FInferLayout>(
"FInferLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
CHECK_EQ(ilayouts->size(), 1U);
CHECK_EQ(olayouts->size(), 1U);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, Layout(param.src_layout));
NNVM_ASSIGN_LAYOUT(*olayouts, 0, Layout(param.dst_layout));
return true;
})
// Compute: builds the output tensor through topi::layout_transform, using an
// index callback that maps destination indices back to source indices.
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& outputs) {
const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
Layout src_layout(param.src_layout);
Layout dst_layout(param.dst_layout);
// Equal layouts need no data movement: forward the input tensor unchanged.
if (src_layout == dst_layout) {
return Array<Tensor>{ inputs[0] };
} else if (!src_layout.defined() || !dst_layout.defined()) {
LOG(FATAL) << "cannot convert from/to undefined layout";
}
CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << param.src_layout
<< " to " << param.dst_layout;
// NOTE(review): the callback captures src_layout/dst_layout by reference -
// this assumes topi::layout_transform invokes it before this lambda returns; confirm.
return Array<Tensor> {
topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) {
// One source index expression per source axis, in source-layout order.
std::vector<Expr> dst_to_src_indices;
for (Layout::LayoutDim src_axis : src_layout) {
// Positions in the destination layout of this axis' super- and sub-dimension.
// Presumably indexof yields a negative value when the dimension is absent
// (see the >= 0 test below) - confirm against the Layout class.
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
// Split factors of this axis on the source and destination side.
int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
Expr src_index(dst_indices[dst_major_pos]);
// Destination splits this axis: recombine major and minor index into the
// flat position along the axis (major * factor + minor).
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
// Source splits this axis: a super-dimension takes the quotient of the flat
// position by the source factor, a sub-dimension takes the remainder.
if (Layout::is_superdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::is_subdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<Expr>(dst_to_src_indices);
})
};
})
.set_support_level(1);
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <nnvm/layout.h>
#include <nnvm/top/nn.h> #include <nnvm/top/nn.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -40,100 +41,47 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) { ...@@ -40,100 +41,47 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
* \param dst_layout target layout * \param dst_layout target layout
* \return shape in target layout * \return shape in target layout
*/ */
inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout& dst_layout) {
  // Identical layouts: the shape is returned unchanged.
  if (src_layout == dst_layout) {
    return src;
  } else if (!src_layout.defined()) {
    LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
  } else if (!dst_layout.defined()) {
    LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
  }

  CHECK(src_layout.convertible(dst_layout)) << "cannot convert from "
                                            << src_layout << " to " << dst_layout;

  TShape dst(dst_layout.ndim());
  for (size_t i = 0; i < src_layout.ndim(); ++i) {
    Layout::LayoutDim src_dim = src_layout[i];
    // Only super-dimensions drive the conversion; a source sub-dimension is
    // folded into its super-dimension below.
    if (Layout::is_superdim(src_dim)) {
      int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_dim));
      int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_dim));
      int src_minor_pos = src_layout.indexof(Layout::to_subdim(src_dim));
      int src_factor = src_layout.subsizeof(src_dim);
      int dst_factor = dst_layout.subsizeof(src_dim);

      // Full extent of this axis in the source: major size times the split
      // factor when the source layout splits it.
      uint32_t src_dim_size = src[i];
      if (src_minor_pos >= 0) {
        CHECK_EQ(src_factor, src[src_minor_pos]) << "src shape " << src
                                                 << " does not agree with layout " << src_layout;
        src_dim_size *= src_factor;
      }

      dst[dst_major_pos] = src_dim_size;
      if (dst_minor_pos >= 0) {
        CHECK_GT(dst_factor, 0);
        // The message reports the destination layout (it used to print the
        // split factor in its place).
        // NOTE(review): an extent that is not a multiple of dst_factor is
        // truncated by the division below - confirm whether that should be
        // rejected.
        CHECK_LE(dst_factor, src_dim_size) << "Converting " << src
                                           << " from " << src_layout
                                           << " to " << dst_layout
                                           << ": cannot split dimension size of "
                                           << src_dim_size << " by " << dst_factor;
        dst[dst_major_pos] /= dst_factor;
        dst[dst_minor_pos] = dst_factor;
      }
    }
  }
  return dst;
}
......
...@@ -30,34 +30,73 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -30,34 +30,73 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs,
TShape dshape = (*in_shape)[0]; TShape dshape = (*in_shape)[0];
if (dshape.ndim() == 0) return false; if (dshape.ndim() == 0) return false;
dshape = ConvertLayout(dshape, param.layout, kNCHW);
CHECK_GE(dshape.ndim(), 2U)
<< "Pool2D only support input >= 2-D: input must have height and width";
Layout layout(param.layout);
CHECK(layout.contains('H') && layout.contains('W') &&
!layout.contains('h') && !layout.contains('w'))
<< "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.indexof('H');
const auto widx = layout.indexof('W');
TShape oshape = dshape; TShape oshape = dshape;
CHECK_EQ(dshape.ndim(), 4U) CHECK(param.pool_size[0] <= dshape[hidx] + 2 * param.padding[0])
<< "Pooling: Input data should be 4D"; << "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[hidx]
CHECK(param.pool_size[0] <= dshape[2] + 2 * param.padding[0]) << " padded to " << (dshape[hidx] + 2*param.padding[0]) << ")";
<< "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[2] CHECK(param.pool_size[1] <= dshape[widx] + 2 * param.padding[1])
<< " padded to " << (dshape[2] + 2*param.padding[0]) << ")"; << "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[widx]
CHECK(param.pool_size[1] <= dshape[3] + 2 * param.padding[1]) << " padded to " << (dshape[widx] + 2*param.padding[1]) << ")";
<< "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[3]
<< " padded to " << (dshape[3] + 2*param.padding[1]) << ")";
if (!param.ceil_mode) { if (!param.ceil_mode) {
oshape[2] = ((dshape[2] + 2 * param.padding[0] - param.pool_size[0]) / oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0]) /
param.strides[0]) + 1; param.strides[0]) + 1;
oshape[3] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1]) / oshape[widx] = ((dshape[widx] + 2 * param.padding[1] - param.pool_size[1]) /
param.strides[1]) + 1; param.strides[1]) + 1;
} else { } else {
oshape[2] = ((dshape[2] + 2 * param.padding[0] - param.pool_size[0] + oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0] +
param.strides[0] - 1) / param.strides[0]) + 1; param.strides[0] - 1) / param.strides[0]) + 1;
oshape[3] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1] + oshape[widx] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1] +
param.strides[1] - 1) / param.strides[1]) + 1; param.strides[1] - 1) / param.strides[1]) + 1;
} }
oshape = ConvertLayout(oshape, kNCHW, param.layout);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true; return true;
} }
/*!
 * \brief Infer the input and output layout of the 2-D pooling operators.
 *
 * The proposed input layout is kept when it is convertible to param.layout,
 * has 'H' and 'W' at the same indices as param.layout and does not split
 * them ('h' / 'w'). In every other case param.layout is requested.
 * Input and output always share the same layout.
 *
 * \param attrs The node attributes; attrs.parsed holds a Pool2DParam.
 * \param ilayouts Layout of the single input, updated in place.
 * \param last_ilayouts Original layout of the single input (size-checked only).
 * \param olayouts Layout of the single output, updated in place.
 * \return Always true.
 */
inline bool Pool2DInferLayout(const NodeAttrs& attrs,
                              std::vector<Layout> *ilayouts,
                              const std::vector<Layout> *last_ilayouts,
                              std::vector<Layout> *olayouts) {
  const Pool2DParam &param = nnvm::get<Pool2DParam>(attrs.parsed);
  // Unsigned literals: size() is unsigned, matching the other layout functions.
  CHECK_EQ(ilayouts->size(), 1U);
  CHECK_EQ(last_ilayouts->size(), 1U);
  CHECK_EQ(olayouts->size(), 1U);

  Layout input = (*ilayouts)[0];
  const Layout layout(param.layout);

  if (!input.defined()) {
    input = layout;
  } else {
    CHECK(input.convertible(layout)) << "Invalid input layout " << input;
    // As long as the index doesn't change for width and height (and neither
    // is split), pool2d can keep the input layout; otherwise fall back to
    // the layout given by the operator parameter.
    if (input.indexof('W') != layout.indexof('W') ||
        input.indexof('H') != layout.indexof('H') ||
        input.contains('w') || input.contains('h')) {
      input = layout;
    }
  }

  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, input);
  return true;
}
NNVM_REGISTER_OP(max_pool2d) NNVM_REGISTER_OP(max_pool2d)
.describe(R"code(Max pooling operation for one dimensional data. .describe(R"code(Max pooling operation for one dimensional data.
...@@ -82,8 +121,8 @@ NNVM_REGISTER_OP(max_pool2d) ...@@ -82,8 +121,8 @@ NNVM_REGISTER_OP(max_pool2d)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape) .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FInferLayout>("FInferLayout", Pool2DInferLayout)
"FTVMCompute", [](const NodeAttrs& attrs, .set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed); const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
...@@ -91,11 +130,20 @@ NNVM_REGISTER_OP(max_pool2d) ...@@ -91,11 +130,20 @@ NNVM_REGISTER_OP(max_pool2d)
auto strides = ShapeToArray(param.strides); auto strides = ShapeToArray(param.strides);
auto padding = ShapeToArray(param.padding); auto padding = ShapeToArray(param.padding);
auto ceil_mode = param.ceil_mode; auto ceil_mode = param.ceil_mode;
CHECK(param.layout == kNCHW || param.layout == kNHWC) << "Unsupported layout";
std::string layout = (param.layout == kNCHW ? "NCHW" : "NHWC"); Layout layout(param.layout);
CHECK(layout.convertible(Layout("NCHW")))
<< "max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.indexof('h'), -1) << "max_pool2d does not support input split on height";
CHECK_EQ(layout.indexof('w'), -1) << "max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";
return Array<Tensor>{ return Array<Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding, \ topi::nn::pool(inputs[0], pool_size, strides, padding,
topi::nn::kMaxPool, ceil_mode, layout) }; topi::nn::kMaxPool, ceil_mode, layout.name())};
}) })
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
...@@ -144,8 +192,8 @@ NNVM_REGISTER_OP(avg_pool2d) ...@@ -144,8 +192,8 @@ NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape) .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FInferLayout>("FInferLayout", Pool2DInferLayout)
"FTVMCompute", [](const NodeAttrs& attrs, .set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed); const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
...@@ -153,11 +201,20 @@ NNVM_REGISTER_OP(avg_pool2d) ...@@ -153,11 +201,20 @@ NNVM_REGISTER_OP(avg_pool2d)
auto strides = ShapeToArray(param.strides); auto strides = ShapeToArray(param.strides);
auto padding = ShapeToArray(param.padding); auto padding = ShapeToArray(param.padding);
auto ceil_mode = param.ceil_mode; auto ceil_mode = param.ceil_mode;
CHECK(param.layout == kNCHW || param.layout == kNHWC) << "Unsupported layout";
std::string layout = (param.layout == kNCHW ? "NCHW" : "NHWC"); Layout layout(param.layout);
CHECK(layout.convertible(Layout("NCHW")))
<< "avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.indexof('h'), -1) << "avg_pool2d does not support input split on height";
CHECK_EQ(layout.indexof('w'), -1) << "avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";
return Array<Tensor>{ return Array<Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding, \ topi::nn::pool(inputs[0], pool_size, strides, padding,
topi::nn::kAvgPool, ceil_mode, layout) }; topi::nn::kAvgPool, ceil_mode, layout.name())};
}) })
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(1) .set_num_inputs(1)
...@@ -169,19 +226,63 @@ DMLC_REGISTER_PARAMETER(GlobalPool2DParam); ...@@ -169,19 +226,63 @@ DMLC_REGISTER_PARAMETER(GlobalPool2DParam);
/*!
 * \brief Shape inference for the global 2-D pooling operators.
 *
 * The output shape equals the input shape with the dimensions at the 'H'
 * and 'W' positions of param.layout set to 1.
 *
 * \param attrs The node attributes; attrs.parsed holds a GlobalPool2DParam.
 * \param in_shape Shapes of the inputs (exactly one).
 * \param out_shape Shapes of the outputs (exactly one), updated in place.
 * \return false when the input shape has zero dimensions, true otherwise.
 */
inline bool GlobalPool2DInferShape(const nnvm::NodeAttrs& attrs,
                                   std::vector<TShape>* in_shape,
                                   std::vector<TShape>* out_shape) {
  const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 1U);
  CHECK_EQ(out_shape->size(), 1U);

  TShape dshape = (*in_shape)[0];
  if (dshape.ndim() == 0) return false;

  CHECK_GE(dshape.ndim(), 2U)
    << "Pool2D only support input >= 2-D: input must have height and width";

  Layout layout(param.layout);
  CHECK(layout.contains('H') && layout.contains('W') &&
        !layout.contains('h') && !layout.contains('w'))
    << "Invalid layout " << layout
    << ". Pool2D layout must have H and W, which cannot be split";

  const auto hidx = layout.indexof('H');
  const auto widx = layout.indexof('W');

  // The rank check above only guarantees two dimensions; make sure the H/W
  // positions taken from the layout actually exist in the shape before
  // writing through them.
  CHECK(static_cast<uint32_t>(hidx) < dshape.ndim() &&
        static_cast<uint32_t>(widx) < dshape.ndim())
    << "Input shape " << dshape << " has too few dimensions for layout " << layout;

  TShape oshape = dshape;
  oshape[hidx] = oshape[widx] = 1;

  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}
/*!
 * \brief Infer the input and output layout of the global 2-D pooling operators.
 *
 * The proposed input layout is kept when it is convertible to param.layout,
 * has 'H' and 'W' at the same indices as param.layout and does not split
 * them ('h' / 'w'). In every other case param.layout is requested.
 * Input and output always share the same layout.
 *
 * \param attrs The node attributes; attrs.parsed holds a GlobalPool2DParam.
 * \param ilayouts Layout of the single input, updated in place.
 * \param last_ilayouts Original layout of the single input (size-checked only).
 * \param olayouts Layout of the single output, updated in place.
 * \return Always true.
 */
inline bool GlobalPool2DInferLayout(const NodeAttrs& attrs,
                                    std::vector<Layout> *ilayouts,
                                    const std::vector<Layout> *last_ilayouts,
                                    std::vector<Layout> *olayouts) {
  const GlobalPool2DParam &param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
  // Unsigned literals: size() is unsigned, matching the other layout functions.
  CHECK_EQ(ilayouts->size(), 1U);
  CHECK_EQ(last_ilayouts->size(), 1U);
  CHECK_EQ(olayouts->size(), 1U);

  Layout input = (*ilayouts)[0];
  const Layout layout(param.layout);

  if (!input.defined()) {
    input = layout;
  } else {
    CHECK(input.convertible(layout)) << "Invalid input layout " << input;
    // As long as the index doesn't change for width and height (and neither
    // is split), pool2d can keep the input layout; otherwise fall back to
    // the layout given by the operator parameter.
    if (input.indexof('W') != layout.indexof('W') ||
        input.indexof('H') != layout.indexof('H') ||
        input.contains('w') || input.contains('h')) {
      input = layout;
    }
  }

  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, input);
  return true;
}
NNVM_REGISTER_OP(global_max_pool2d) NNVM_REGISTER_OP(global_max_pool2d)
.describe(R"code(Global max pooling operation for 2D data. .describe(R"code(Global max pooling operation for 2D data.
...@@ -197,15 +298,26 @@ NNVM_REGISTER_OP(global_max_pool2d) ...@@ -197,15 +298,26 @@ NNVM_REGISTER_OP(global_max_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape) .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", GlobalPool2DInferLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed); const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
CHECK_EQ(param.layout, kNCHW) Layout layout(param.layout);
<< "global_max_pool2d currently only supports NCHW layout"; CHECK(layout.convertible(Layout("NCHW")))
<< "global_max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.indexof('h'), -1)
<< "global_max_pool2d does not support input split on height";
CHECK_EQ(layout.indexof('w'), -1)
<< "global_max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";
return Array<Tensor>{ return Array<Tensor>{
topi::nn::global_pool(inputs[0], topi::nn::kMaxPool) }; topi::nn::global_pool(inputs[0], topi::nn::kMaxPool, layout.name()) };
}) })
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(1) .set_num_inputs(1)
...@@ -227,15 +339,26 @@ NNVM_REGISTER_OP(global_avg_pool2d) ...@@ -227,15 +339,26 @@ NNVM_REGISTER_OP(global_avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape) .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", GlobalPool2DInferLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed); const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
CHECK_EQ(param.layout, kNCHW) Layout layout(param.layout);
<< "global_avg_pool2d currently only supports NCHW layout"; CHECK(layout.convertible(Layout("NCHW")))
<< "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.indexof('h'), -1)
<< "global_avg_pool2d does not support input split on height";
CHECK_EQ(layout.indexof('w'), -1)
<< "global_avg_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";
return Array<Tensor>{ return Array<Tensor>{
topi::nn::global_pool(inputs[0], topi::nn::kAvgPool) }; topi::nn::global_pool(inputs[0], topi::nn::kAvgPool, layout.name()) };
}) })
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(1) .set_num_inputs(1)
......
...@@ -19,6 +19,7 @@ DMLC_REGISTER_PARAMETER(UpSamplingParam); ...@@ -19,6 +19,7 @@ DMLC_REGISTER_PARAMETER(UpSamplingParam);
inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs, inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape, std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) { std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed); const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(in_shape->size(), 1U);
CHECK_EQ(out_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U);
...@@ -33,6 +34,19 @@ inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs, ...@@ -33,6 +34,19 @@ inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool UpsamplingLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
CHECK_EQ(in_layouts->size(), 1U);
CHECK_EQ(out_layouts->size(), 1U);
const Layout layout(param.layout);
NNVM_ASSIGN_LAYOUT(*in_layouts, 0, layout);
NNVM_ASSIGN_LAYOUT(*out_layouts, 0, layout);
return true;
}
NNVM_REGISTER_OP(upsampling) NNVM_REGISTER_OP(upsampling)
.describe(R"(Perform nearest neighbor upsampling to input array. .describe(R"(Perform nearest neighbor upsampling to input array.
...@@ -46,6 +60,7 @@ NNVM_REGISTER_OP(upsampling) ...@@ -46,6 +60,7 @@ NNVM_REGISTER_OP(upsampling)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>)
.set_attr<FInferShape>("FInferShape", UpSamplingInferShape) .set_attr<FInferShape>("FInferShape", UpSamplingInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", UpsamplingLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(1) .set_num_inputs(1)
.set_support_level(2); .set_support_level(2);
......
...@@ -203,6 +203,13 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs, ...@@ -203,6 +203,13 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
} \ } \
} }
/*!
 * \brief Assign a layout to (outputs)[index] if the layout is defined.
 *  An undefined layout leaves the slot untouched.
 *  The layout argument is parenthesized and evaluated exactly once, so
 *  expressions with side effects or lower-precedence operators
 *  (e.g. `cond ? a : b`) are safe to pass.
 * \param outputs The indexable container of layouts to write into.
 * \param index The slot to assign.
 * \param layout The layout expression; must provide a const `defined()`.
 */
#define NNVM_ASSIGN_LAYOUT(outputs, index, layout)                  \
  {                                                                 \
    const auto& nnvm_assign_layout_tmp_ = (layout);                 \
    if (nnvm_assign_layout_tmp_.defined()) {                        \
      (outputs)[index] = nnvm_assign_layout_tmp_;                   \
    }                                                               \
  }
/*! /*!
* \brief macro assign rhs shape to lhs * \brief macro assign rhs shape to lhs
* Use macro so we can see the error file more clearly * Use macro so we can see the error file more clearly
...@@ -253,6 +260,14 @@ inline bool ZeroShape(const NodeAttrs& attrs, ...@@ -253,6 +260,14 @@ inline bool ZeroShape(const NodeAttrs& attrs,
} }
} }
/*!
 * \brief Layout function that performs no inference.
 *  None of the layout vectors is read or modified.
 * \param attrs The node attributes (unused).
 * \param in_layouts The input layouts (left untouched).
 * \param last_in_layouts The original input layouts (unused).
 * \param out_layouts The output layouts (left untouched).
 * \return Always true.
 */
inline bool ZeroLayout(const NodeAttrs& attrs,
                       std::vector<Layout> *in_layouts,
                       const std::vector<Layout> *last_in_layouts,
                       std::vector<Layout> *out_layouts) {
  // do not infer layout
  return true;
}
// simply assign output shape or type from input // simply assign output shape or type from input
template<typename AttrType, int in_index, int out_index> template<typename AttrType, int in_index, int out_index>
inline bool AssignOutputAttr(const NodeAttrs& attrs, inline bool AssignOutputAttr(const NodeAttrs& attrs,
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <nnvm/compiler/op_attr_types.h> #include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h> #include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h> #include <nnvm/top/tensor.h>
#include <nnvm/top/nn.h>
#include "../op_common.h" #include "../op_common.h"
#include "../elemwise_op_common.h" #include "../elemwise_op_common.h"
#include "topi/broadcast.h" #include "topi/broadcast.h"
...@@ -74,6 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example. ...@@ -74,6 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape) .set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -115,7 +117,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, ...@@ -115,7 +117,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
} else { } else {
CHECK(l == 1 || r == 1) CHECK(l == 1 || r == 1)
<< "operands could not be broadcast together with shapes " << "operands could not be broadcast together with shapes "
<< lhs << " " << rhs; << lhs << " " << rhs << ", l=" << l << ", r=" << r;
out[i] = std::max(l, r); out[i] = std::max(l, r);
} }
} else { } else {
...@@ -126,6 +128,77 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, ...@@ -126,6 +128,77 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
/*!
 * \brief Infer the layouts of a binary broadcast operator.
 *
 * Expects exactly two input layouts and one output layout.
 *  - Both inputs defined and equal: the output takes that layout and the
 *    function returns immediately, without touching the input layouts.
 *  - Both inputs defined but different: the layouts must be
 *    "broadcast-convertible" (see the comment in the body); the shorter side
 *    is requested to become the trailing sub-layout of the longer one and the
 *    output takes the longer layout.
 *  - Only one input defined: if a previous layout was recorded for that input
 *    it is restored (after checking convertibility); the output layout stays
 *    undefined.
 *  - Neither input defined: all three layouts are passed to the assign macro
 *    as they are (the output as undefined).
 *
 * \param attrs The attributes of the node (not read here).
 * \param ilayouts The input layouts; may be rewritten with the layouts this
 *        operator requests.
 * \param last_ilayouts The previously recorded input layouts.
 * \param olayouts The output layouts.
 * \return true unless a CHECK / LOG(FATAL) aborts.
 */
inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), 2U);
CHECK_EQ(olayouts->size(), 1U);
// Work on copies; the results are written back at the end of the function.
Layout lhs = (*ilayouts)[0];
Layout rhs = (*ilayouts)[1];
Layout out(Layout::Undef());
if (lhs.defined() && rhs.defined()) {
if (lhs == rhs) {
// Identical layouts: nothing to convert, the output simply follows.
NNVM_ASSIGN_LAYOUT(*olayouts, 0, lhs);
return true;
}
// For example, NCHW <-> CHW, N16nCH16cW <-> HCW16c, etc., are broadcast-convertible
// because, by definition, CHW can broadcast with NCHW.
// For the second case, we can convert HCW16c to CH16cW, which can then broadcast
// with N16nCH16cW.
// But CNHW <-> CHW and NCHW16n <-> CHW are not,
// because no matter how we adjust the layout of 'CHW',
// we can never have an 'N' between 'C' and "HW".
// l_start / r_start: index just past the last leading dimension of lhs / rhs
// that has no counterpart in the other layout.
size_t l_start = 0, r_start = 0;
size_t l = 0, r = 0;
// Set once a dimension present in both layouts has been seen; from then on
// an unmatched dimension on either side is an error.
bool find_first_match = false;
while (l < lhs.ndim() && r < rhs.ndim()) {
// NOTE(review): to_superdim presumably maps a sub-dimension such as 'c' to
// its primary dimension 'C' - confirm against the Layout class.
if (!rhs.contains(Layout::to_superdim(lhs[l]))) {
CHECK(!find_first_match) << lhs << " and " << rhs << " are not broadcast-convertible";
l_start = ++l;
} else if (!lhs.contains(Layout::to_superdim(rhs[r]))) {
CHECK(!find_first_match) << lhs << " and " << rhs << " are not broadcast-convertible";
r_start = ++r;
} else {
find_first_match = true;
++l; ++r;
}
}
if (l_start > 0 && r_start > 0) {
// Both sides carry leading dimensions the other one lacks.
LOG(FATAL) << lhs << " and " << rhs << " are not broadcast-convertible";
} else if (l_start > 0) {
// lhs has extra leading dimensions: request rhs as the remainder of lhs.
rhs = lhs.sublayout(l_start, lhs.ndim()-l_start);
out = lhs;
} else if (r_start > 0) {
// rhs has extra leading dimensions: request lhs as the remainder of rhs.
lhs = rhs.sublayout(r_start, rhs.ndim()-r_start);
out = rhs;
} else {
// No extra leading dimension on either side: prefer to keep the left layout.
rhs = lhs;
out = lhs;
}
} else if (lhs.defined()) {
const Layout& last_lhs = last_ilayouts->at(0);
if (last_lhs.defined()) {
CHECK(lhs.convertible(last_lhs)) << "current lhs layout " << lhs
<< " cannot be converted to the original one " << last_lhs;
lhs = last_lhs;
// rhs is undefined, so the output layout cannot be decided
}
} else if (rhs.defined()) {
const Layout& last_rhs = last_ilayouts->at(1);
if (last_rhs.defined()) {
CHECK(rhs.convertible(last_rhs)) << "current rhs layout " << rhs
<< " cannot be converted to the original one " << last_rhs;
rhs = last_rhs;
// lhs is undefined, so the output layout cannot be decided
}
}
// Publish the requested input layouts and the inferred output layout.
// NOTE(review): NNVM_ASSIGN_LAYOUT is defined elsewhere; how it treats an
// undefined layout is not visible here - confirm.
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, lhs);
NNVM_ASSIGN_LAYOUT(*ilayouts, 1, rhs);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, out);
return true;
}
#define NNVM_REGISTER_BINARY_BROADCAST_OP(name) \ #define NNVM_REGISTER_BINARY_BROADCAST_OP(name) \
NNVM_REGISTER_OP(name) \ NNVM_REGISTER_OP(name) \
...@@ -133,6 +206,8 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, ...@@ -133,6 +206,8 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \ .set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
BinaryBroadcastInferLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \ [](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
......
...@@ -333,6 +333,7 @@ NNVM_REGISTER_INIT_OP(full) ...@@ -333,6 +333,7 @@ NNVM_REGISTER_INIT_OP(full)
.add_arguments(InitOpWithScalarParam::__FIELDS__()) .add_arguments(InitOpWithScalarParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
NNVM_REGISTER_INIT_OP(zeros) NNVM_REGISTER_INIT_OP(zeros)
...@@ -345,6 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros) ...@@ -345,6 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
NNVM_REGISTER_INIT_OP(ones) NNVM_REGISTER_INIT_OP(ones)
...@@ -357,6 +359,7 @@ NNVM_REGISTER_INIT_OP(ones) ...@@ -357,6 +359,7 @@ NNVM_REGISTER_INIT_OP(ones)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
// full_like // full_like
...@@ -693,6 +696,7 @@ Example:: ...@@ -693,6 +696,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -41,6 +41,31 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, ...@@ -41,6 +41,31 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
/*!
 * \brief Infer the layouts of the matmul operator.
 *
 * Each input keeps its previously recorded layout when one is defined,
 * otherwise its current layout. When both resulting layouts have more than
 * one dimension, the output layout is built from the (optionally reversed,
 * per transpose_a / transpose_b) inputs: every dimension of lhs except the
 * last, followed by every dimension of rhs except the first.
 * Otherwise the output layout is left untouched.
 *
 * \param attrs The attributes of the node; attrs.parsed holds a MatMulParam.
 * \param ilayouts The input layouts (exactly two), updated in place.
 * \param last_ilayouts The previously recorded input layouts.
 * \param olayouts The output layouts (exactly one).
 * \return Always true (size mismatches abort through CHECK_EQ).
 */
inline bool DotInferLayout(const NodeAttrs& attrs,
                           std::vector<Layout> *ilayouts,
                           const std::vector<Layout> *last_ilayouts,
                           std::vector<Layout> *olayouts) {
  const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
  CHECK_EQ(ilayouts->size(), 2U);
  CHECK_EQ(olayouts->size(), 1U);
  // Prefer the layout recorded by an earlier pass over the current one.
  const Layout& lhs = last_ilayouts->at(0).defined() ? last_ilayouts->at(0)
                                                     : ilayouts->at(0);
  const Layout& rhs = last_ilayouts->at(1).defined() ? last_ilayouts->at(1)
                                                     : ilayouts->at(1);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, lhs);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 1, rhs);

  if (lhs.ndim() > 1 && rhs.ndim() > 1) {
    // concat lhs and rhs layout
    // The conditional yields a temporary whenever reverse() is selected;
    // binding it to a const reference extends its lifetime to this scope.
    const Layout& lhs_out = param.transpose_a ? lhs.reverse() : lhs;
    const Layout& rhs_out = param.transpose_b ? rhs.reverse() : rhs;
    // The sum is already a temporary: initialise directly from it instead of
    // wrapping it in std::move, which would only inhibit copy elision.
    Layout out = lhs_out.sublayout(0, lhs_out.ndim()-1) +
                 rhs_out.sublayout(1, rhs_out.ndim()-1);
    NNVM_ASSIGN_LAYOUT(*olayouts, 0, out);
  }
  return true;
}
NNVM_REGISTER_OP(matmul) NNVM_REGISTER_OP(matmul)
.describe(R"doc(Matrix multiplication of two arrays. .describe(R"doc(Matrix multiplication of two arrays.
...@@ -67,6 +92,7 @@ NNVM_REGISTER_OP(matmul) ...@@ -67,6 +92,7 @@ NNVM_REGISTER_OP(matmul)
.add_argument("rhs", "NDArray-or-Symbol", "The second input") .add_argument("rhs", "NDArray-or-Symbol", "The second input")
.set_attr<FInferShape>("FInferShape", DotShape) .set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferLayout>("FInferLayout", DotInferLayout)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
......
...@@ -111,6 +111,8 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) { ...@@ -111,6 +111,8 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \ .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_attr<FInferShape>("FInferShape", ReduceShape) \ .set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.set_num_outputs(1) .set_num_outputs(1)
......
...@@ -45,6 +45,15 @@ This is an experimental operator. ...@@ -45,6 +45,15 @@ This is an experimental operator.
return Array<Tensor>{ topi::identity(inputs[1]) }; return Array<Tensor>{ topi::identity(inputs[1]) };
}) })
.set_attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferLayout>(
"FInferLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
NNVM_ASSIGN_LAYOUT(*in_layouts, 1, (*in_layouts)[0]);
NNVM_ASSIGN_LAYOUT(*out_layouts, 0, (*in_layouts)[0]);
return true;
})
.set_attr<FInplaceOption>( .set_attr<FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) { "FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{1, 0}}; return std::vector<std::pair<int, int> >{{1, 0}};
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <nnvm/compiler/util.h> #include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h> #include <nnvm/top/tensor.h>
#include <cctype> #include <cctype>
#include <sstream>
#include "../op_common.h" #include "../op_common.h"
#include "../elemwise_op_common.h" #include "../elemwise_op_common.h"
#include "topi/nn/flatten.h" #include "topi/nn/flatten.h"
...@@ -63,6 +64,7 @@ Example:: ...@@ -63,6 +64,7 @@ Example::
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", FlattenInferShape) .set_attr<FInferShape>("FInferShape", FlattenInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
...@@ -119,6 +121,22 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs, ...@@ -119,6 +121,22 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
return dshape.Size() != 0; return dshape.Size() != 0;
} }
/*!
 * \brief Infer the layouts of the concatenate operator.
 *
 * Every input keeps its previously recorded layout when one is defined and
 * its current layout otherwise. The output layout is never written here.
 *
 * \param attrs The attributes of the node (not read here).
 * \param ilayouts The input layouts, updated in place.
 * \param last_ilayouts The previously recorded input layouts; must have the
 *        same number of entries as ilayouts.
 * \param olayouts The output layouts (exactly one, left untouched).
 * \return Always true (size mismatches abort through CHECK_EQ).
 */
inline bool ConcatenateInferLayout(const NodeAttrs& attrs,
                                   std::vector<Layout> *ilayouts,
                                   const std::vector<Layout> *last_ilayouts,
                                   std::vector<Layout> *olayouts) {
  CHECK_EQ(ilayouts->size(), last_ilayouts->size());
  CHECK_EQ(olayouts->size(), 1U);

  const size_t num_inputs = ilayouts->size();
  for (size_t i = 0; i != num_inputs; ++i) {
    // Fall back to the current layout only when nothing was recorded before.
    const Layout& previous = last_ilayouts->at(i);
    const Layout& input = previous.defined() ? previous : ilayouts->at(i);
    NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
  }
  return true;
}
NNVM_REGISTER_OP(concatenate) NNVM_REGISTER_OP(concatenate)
.describe(R"code(Joins input arrays along a given axis. .describe(R"code(Joins input arrays along a given axis.
...@@ -156,6 +174,7 @@ Example:: ...@@ -156,6 +174,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape) .set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", ConcatenateInferLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -177,7 +196,8 @@ inline bool ExpandDimsInferShape(const NodeAttrs& attrs, ...@@ -177,7 +196,8 @@ inline bool ExpandDimsInferShape(const NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(in_shape->size(), 1U);
const TShape& dshape = in_shape->at(0); const TShape& dshape = in_shape->at(0);
int ndim = static_cast<int>(dshape.ndim()); int ndim = static_cast<int>(dshape.ndim());
CHECK(param.axis >= -ndim - 1 && param.axis <= ndim); CHECK(param.axis >= -ndim - 1 && param.axis <= ndim)
<< "with axis = " << param.axis << " ndim = " << ndim;
int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis; int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis;
std::vector<dim_t> oshape; std::vector<dim_t> oshape;
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
...@@ -198,7 +218,7 @@ NNVM_REGISTER_OP(expand_dims) ...@@ -198,7 +218,7 @@ NNVM_REGISTER_OP(expand_dims)
.describe(R"code(Inserts a new axis of size 1 into the array shape .describe(R"code(Inserts a new axis of size 1 into the array shape
For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1, num_newaxis=5)`` For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1, num_newaxis=5)``
will return a new array with shape ``(2,5,3,4)``. will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input tensor") .add_argument("data", "Tensor", "Input tensor")
...@@ -207,6 +227,7 @@ will return a new array with shape ``(2,5,3,4)``. ...@@ -207,6 +227,7 @@ will return a new array with shape ``(2,5,3,4)``.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape) .set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -249,6 +270,8 @@ Examples:: ...@@ -249,6 +270,8 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>)
.set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>) .set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FGradient>( .set_attr<FGradient>(
...@@ -345,6 +368,7 @@ along which to split the array. ...@@ -345,6 +368,7 @@ along which to split the array.
.set_attr_parser(SplitParamParser) .set_attr_parser(SplitParamParser)
.set_attr<FInferShape>("FInferShape", SplitInferShape) .set_attr<FInferShape>("FInferShape", SplitInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, -1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(SplitNumOutputs) .set_num_outputs(SplitNumOutputs)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -387,6 +411,7 @@ NNVM_REGISTER_OP(cast) ...@@ -387,6 +411,7 @@ NNVM_REGISTER_OP(cast)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", CastInferType) .set_attr<FInferType>("FInferType", CastInferType)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_support_level(1); .set_support_level(1);
...@@ -539,6 +564,7 @@ The significance of each is explained below: ...@@ -539,6 +564,7 @@ The significance of each is explained below:
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>)
.set_attr<FInferShape>("FInferShape", ReshapeInferShape) .set_attr<FInferShape>("FInferShape", ReshapeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -578,6 +604,8 @@ the input array into an output array with the same shape as the second input arr ...@@ -578,6 +604,8 @@ the input array into an output array with the same shape as the second input arr
return true; return true;
}) })
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
...@@ -660,6 +688,7 @@ Examples:: ...@@ -660,6 +688,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape) .set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -680,7 +709,7 @@ Examples:: ...@@ -680,7 +709,7 @@ Examples::
}) })
.set_support_level(1); .set_support_level(1);
// tranpose // transpose
DMLC_REGISTER_PARAMETER(TransposeParam); DMLC_REGISTER_PARAMETER(TransposeParam);
inline bool TransposeShape(const nnvm::NodeAttrs& attrs, inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
...@@ -708,6 +737,39 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, ...@@ -708,6 +737,39 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
/*!
 * \brief Infer the layouts of the transpose operator.
 *
 * The input keeps its previously recorded layout when one is defined and its
 * current layout otherwise. If that layout is defined, the output layout is
 * the input layout permuted by param.axes, or fully reversed when no axes
 * are given. If it is undefined, the output layout is left untouched.
 *
 * \param attrs The attributes of the node; attrs.parsed holds a TransposeParam.
 * \param ilayouts The input layouts (exactly one), updated in place.
 * \param last_ilayouts The previously recorded input layouts.
 * \param olayouts The output layouts (exactly one).
 * \return Always true (violated checks abort through CHECK / CHECK_EQ).
 */
inline bool TransposeInferLayout(const NodeAttrs& attrs,
                                 std::vector<Layout> *ilayouts,
                                 const std::vector<Layout> *last_ilayouts,
                                 std::vector<Layout> *olayouts) {
  const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
  CHECK_EQ(ilayouts->size(), 1U);
  CHECK_EQ(olayouts->size(), 1U);

  const Layout& previous = last_ilayouts->at(0);
  const Layout& input = previous.defined() ? previous : ilayouts->at(0);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);

  // Without a known input layout there is nothing to permute.
  if (!input.defined()) {
    return true;
  }

  std::ostringstream new_layout;
  if (param.axes.ndim() != 0) {
    // Explicit permutation: output dimension i is input dimension axes[i].
    CHECK_EQ(input.ndim(), param.axes.ndim());
    for (size_t i = 0; i < input.ndim(); ++i) {
      CHECK(param.axes[i] < input.ndim());
      new_layout << input.at(param.axes[i]);
    }
  } else {
    // No axes given: emit the dimensions from last to first.
    for (size_t i = input.ndim(); i > 0; --i) {
      new_layout << input.at(i - 1);
    }
  }
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, Layout(new_layout.str()));
  return true;
}
NNVM_REGISTER_OP(transpose) NNVM_REGISTER_OP(transpose)
.describe(R"code(Permutes the dimensions of an array. .describe(R"code(Permutes the dimensions of an array.
...@@ -743,6 +805,7 @@ Examples:: ...@@ -743,6 +805,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", TransposeShape) .set_attr<nnvm::FInferShape>("FInferShape", TransposeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", TransposeInferLayout)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_support_level(4) .set_support_level(4)
......
"""Unittest cases for AlterOpLayout pass"""
from nnvm import symbol as sym
from nnvm.compiler import graph_attr
from nnvm.top import registry as reg
import nnvm.graph as graph
def get_layouts(g):
    """Collect the layouts recorded on a graph, keyed by node name.

    Reads the graph's "layout" JSON attribute (one value per entry) and
    slices it per node using ``g.index.entry_ptr``, where entries
    ``entry_ptr[i]:entry_ptr[i + 1]`` belong to node ``i``.

    Parameters
    ----------
    g : graph object providing ``json_attr("layout")`` and ``index``.

    Returns
    -------
    dict mapping each node's ``"name"`` to the list of its entry layouts.
    """
    ldict = {}
    vlayout = g.json_attr("layout")
    entry_ptr = g.index.entry_ptr
    for i, n in enumerate(g.index.nodes):
        begin, end = entry_ptr[i], entry_ptr[i + 1]
        ldict[n["name"]] = vlayout[begin:end]
    return ldict
def test_alter_conv2d_layout():
    """AlterOpLayout must keep the layouts seen by the surrounding nodes.

    Builds data -> conv2d -> relu -> flatten -> softmax, records the layouts
    after CorrectLayout, registers an alter-layout hook that rewrites the
    conv2d to NCHW16c, applies AlterOpLayout and checks that every other node
    (and the replaced conv output) reports the same layouts as before.
    """
    data = sym.Variable("data", shape=(1, 32, 512, 512))
    conv = sym.conv2d(data, name="conv", channels=16,
                      kernel_size=(3,3), padding=(1,1),
                      use_bias=False, layout="NCHW")
    relu = sym.relu(conv, name="relu")
    flatten = sym.flatten(relu, name="flatten")
    softmax = sym.softmax(flatten, name="softmax")
    g = graph.create(softmax)
    g = g.apply("CorrectLayout")
    g = graph_attr.set_dtype_inputs(g, "float32")
    g = g.apply(["InferShape", "InferType"])
    layouts_origin = get_layouts(g)

    @reg.register_alter_op_layout("conv2d")
    def alter_conv2d_layout(attrs, inputs, tinfos):
        # Copy every original attribute, then override the layout-related ones.
        new_attrs = {k : attrs[k] for k in attrs.keys()}
        new_attrs["layout"] = "NCHW16c"
        new_attrs["kernel_layout"] = "NCHW16c"
        new_attrs["name"] = "conv_alter"
        return sym.conv2d(inputs[0], inputs[1], **new_attrs)

    g = g.apply("AlterOpLayout")
    layouts = get_layouts(g)

    # check copy layouts
    for node in ["data", "relu", "flatten", "softmax", "conv_weight"]:
        assert layouts[node] == layouts_origin[node]
    assert layouts["conv_alter"] == layouts_origin["conv"]
# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_alter_conv2d_layout()
...@@ -5,9 +5,10 @@ import nnvm.symbol as sym ...@@ -5,9 +5,10 @@ import nnvm.symbol as sym
import nnvm.compiler import nnvm.compiler
from nnvm.testing.config import ctx_list from nnvm.testing.config import ctx_list
def get_sym(layout, channels): def get_sym(layout, kernel_layout, channels):
data = sym.Variable(name="data") data = sym.Variable(name="data")
data = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1), layout=layout, use_bias=True) data = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1),
layout=layout, kernel_layout=kernel_layout, use_bias=True)
data = sym.max_pool2d(data=data, pool_size=(2, 2), strides=(2, 2), layout=layout) data = sym.max_pool2d(data=data, pool_size=(2, 2), strides=(2, 2), layout=layout)
data = sym.upsampling(data=data, scale=2, layout=layout) data = sym.upsampling(data=data, scale=2, layout=layout)
softmax_axis = 1 softmax_axis = 1
...@@ -31,8 +32,8 @@ def build_and_run(sym, params, data, out_shape): ...@@ -31,8 +32,8 @@ def build_and_run(sym, params, data, out_shape):
def test_nhwc(): def test_nhwc():
data_shape = (1, 3, 224, 224) data_shape = (1, 3, 224, 224)
out_channel = 8 out_channel = 8
nchw_sym = get_sym("NCHW", out_channel) nchw_sym = get_sym("NCHW", "OIHW", out_channel)
nhwc_sym = get_sym("NHWC", out_channel) nhwc_sym = get_sym("NHWC", "HWIO", out_channel)
conv_weight = np.random.uniform(-1, 1, (out_channel, 3, 3, 3)).astype(np.float32) conv_weight = np.random.uniform(-1, 1, (out_channel, 3, 3, 3)).astype(np.float32)
conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32) conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32)
nchw_params = { nchw_params = {
......
import nnvm
import nnvm.symbol as sym
import nnvm.graph as graph
from nnvm.compiler import graph_attr
# Level 1
def correct_layout(g, layout=None):
    """Run the CorrectLayout pass and collect the resulting layouts.

    Parameters
    ----------
    g : nnvm Symbol or graph
        A Symbol is first turned into a graph with ``graph.create``.
    layout : optional
        When truthy, passed to ``graph_attr.set_layout_inputs`` before the
        pass runs (the tests pass either one string or a name -> layout dict).

    Returns
    -------
    (g, ldict) : the graph after the pass, and a dict mapping each node name
    to the list of layouts of its entries.
    """
    if isinstance(g, nnvm.symbol.Symbol):
        g = graph.create(g)
    if layout:
        graph_attr.set_layout_inputs(g, layout)
    g = g.apply("CorrectLayout")
    ldict = {}
    vlayout = g.json_attr("layout")
    entry_ptr = g.index.entry_ptr
    for i, n in enumerate(g.index.nodes):
        # Entries entry_ptr[i]:entry_ptr[i + 1] belong to node i.
        begin, end = entry_ptr[i], entry_ptr[i + 1]
        ldict[n["name"]] = vlayout[begin:end]
    return g, ldict
def test_dense():
    """dense keeps its "HW" data layout; a new input layout gets a transform."""
    x = sym.Variable("data", shape=(10, 20))
    y = sym.dense(x, units=30, name="fc")
    g, ldict = correct_layout(y, "HW")
    assert ldict["data"][0] == "HW"
    assert ldict["fc"][0] == "HW"
    assert ldict["fc_bias"][0] == "__undef__"
    # second pass will insert layout transform
    _, ldict = correct_layout(g, "HW16w")
    assert ldict["data"][0] == "HW16w"
    assert ldict["data_HW"][0] == "HW"
    assert ldict["fc"][0] == "HW"
    assert ldict["fc_bias"][0] == "__undef__"
def test_matmul():
    """matmul output layout joins lhs and rhs layouts, honouring transposes."""
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(20, 30))
    c = sym.matmul(a, b, name="matmul")
    g, ldict = correct_layout(c, {"a" : "HW", "b" : "WC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "WC"
    assert ldict["matmul"][0] == "HC"
    # second pass will insert layout transform
    _, ldict = correct_layout(g, {"a" : "HW16w", "b" : "WC16c"})
    assert ldict["a"][0] == "HW16w"
    assert ldict["a_HW"][0] == "HW"
    assert ldict["b"][0] == "WC16c"
    assert ldict["b_WC"][0] == "WC"
    assert ldict["matmul"][0] == "HC"

    # lhs transposed
    a = sym.Variable("a", shape=(20, 10))
    c = sym.matmul(a, b, name="matmul", transpose_a=True)
    g, ldict = correct_layout(c, {"a" : "HW", "b" : "HC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "HC"
    assert ldict["matmul"][0] == "WC"

    # rhs transposed
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_b=True)
    g, ldict = correct_layout(c, {"a" : "HW", "b" : "CW"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CW"
    assert ldict["matmul"][0] == "HC"

    # both transposed
    a = sym.Variable("a", shape=(20, 10))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
    g, ldict = correct_layout(c, {"a" : "HW", "b" : "CH"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CH"
    assert ldict["matmul"][0] == "WC"
def test_concatenate():
    """concatenate keeps its input layouts and leaves the output undefined."""
    x1 = sym.Variable("x", shape=(10, 20))
    x2 = sym.Variable("y", shape=(10, 30))
    z = sym.concatenate(x1, x2, name="concat")
    g, ldict = correct_layout(z, {"x": "HW", "y": "HW"})
    assert ldict["x"][0] == "HW"
    assert ldict["y"][0] == "HW"
    assert ldict["concat"][0] == "__undef__"
    # second pass will insert layout transform
    _, ldict = correct_layout(g, {"x": "HW16w", "y": "HW16w"})
    assert ldict["x"][0] == "HW16w"
    assert ldict["y"][0] == "HW16w"
    assert ldict["x_HW"][0] == "HW"
    assert ldict["y_HW"][0] == "HW"
    assert ldict["concat"][0] == "__undef__"
def test_expand_dims():
    """expand_dims fixes its input layout and reports an undefined output."""
    x = sym.Variable("x", shape=(10, 20))
    y = sym.expand_dims(x, axis=1, name="y")
    g, ldict = correct_layout(y, "HW")
    assert ldict["x"][0] == "HW"
    assert ldict["y"][0] == "__undef__"
    # second pass will insert layout transform
    _, ldict = correct_layout(g, "HW16w")
    assert ldict["x"][0] == "HW16w"
    assert ldict["x_HW"][0] == "HW"
    assert ldict["y"][0] == "__undef__"
def test_split():
    """split fixes its input layout and reports an undefined output."""
    x = sym.Variable("x", shape=(10, 20))
    y = sym.split(x, indices_or_sections=[11], name="y")
    g, ldict = correct_layout(y, "HW")
    assert ldict["x"][0] == "HW"
    assert ldict["y"][0] == "__undef__"
    # second pass will insert layout transform
    _, ldict = correct_layout(g, "HW16w")
    assert ldict["x"][0] == "HW16w"
    assert ldict["x_HW"][0] == "HW"
    assert ldict["y"][0] == "__undef__"
def test_batchnorm():
    """batch_norm layouts for NCHW, NCHW16c and a layout needing a transform."""
    x = sym.Variable("data", shape=(10, 20, 30, 40))
    y = sym.batch_norm(x, axis=1, epsilon=2e-5, name="bn")
    g, ldict = correct_layout(y, "NCHW")
    assert ldict["data"][0] == "NCHW"
    assert ldict["bn"][0] == "NCHW"
    assert ldict["bn"][1] == "C"
    assert ldict["bn"][2] == "C"
    assert ldict["bn_beta"][0] == "C"
    assert ldict["bn_gamma"][0] == "C"
    assert ldict["bn_moving_mean"][0] == "C"
    assert ldict["bn_moving_var"][0] == "C"

    # batch_norm can deal with sub-dim of C at the last dim.
    g, ldict = correct_layout(g, "NCHW16c")
    assert ldict["data"][0] == "NCHW16c"
    assert ldict["bn"][0] == "NCHW16c"
    assert ldict["bn"][1] == "C16c"
    assert ldict["bn"][2] == "C16c"
    assert ldict["bn_beta"][0] == "C"
    assert ldict["bn_beta_C16c"][0] == "C16c"
    assert ldict["bn_gamma"][0] == "C"
    assert ldict["bn_gamma_C16c"][0] == "C16c"
    assert ldict["bn_moving_mean"][0] == "C"
    assert ldict["bn_moving_mean_C16c"][0] == "C16c"
    assert ldict["bn_moving_var"][0] == "C"
    assert ldict["bn_moving_var_C16c"][0] == "C16c"

    # but for other layout, it does a layout transform for data
    g, ldict = correct_layout(g, "NCH16cW")
    assert ldict["data"][0] == "NCH16cW"
    assert ldict["data_NCHW16c"][0] == "NCHW16c"
    assert ldict["bn"][0] == "NCHW16c"
    assert ldict["bn"][1] == "C16c"
    assert ldict["bn"][2] == "C16c"
    assert ldict["bn_beta"][0] == "C"
    assert ldict["bn_beta_C16c"][0] == "C16c"
    assert ldict["bn_gamma"][0] == "C"
    assert ldict["bn_gamma_C16c"][0] == "C16c"
    assert ldict["bn_moving_mean"][0] == "C"
    assert ldict["bn_moving_mean_C16c"][0] == "C16c"
    assert ldict["bn_moving_var"][0] == "C"
    assert ldict["bn_moving_var_C16c"][0] == "C16c"
def test_flatten():
    """flatten fixes its input layout and reports an undefined output."""
    x = sym.Variable("x", shape=(10, 20, 10, 10))
    y = sym.flatten(x, name="y")
    g, ldict = correct_layout(y, "NCHW")
    assert ldict["x"][0] == "NCHW"
    assert ldict["y"][0] == "__undef__"
    # second pass will insert layout transform
    _, ldict = correct_layout(g, "NCHW16c")
    assert ldict["x"][0] == "NCHW16c"
    assert ldict["x_NCHW"][0] == "NCHW"
    assert ldict["y"][0] == "__undef__"
# Level 2
def test_conv2d():
    """conv2d derives data, weight, bias and output layouts from its attrs."""
    x = sym.Variable("data", shape=(1, 32, 512, 512))
    y = sym.conv2d(x, name="conv", channels=12,
                   kernel_size=(3,3), padding=(1,1), layout="NCHW")
    _, ldict = correct_layout(y)
    assert ldict["data"][0] == "NCHW"
    assert ldict["conv_weight"][0] == "OIHW"
    assert ldict["conv_bias"][0] == "C"
    assert ldict["conv"][0] == "NCHW"

    # explicit data, kernel and output layouts
    y = sym.conv2d(x, name="conv", channels=12,
                   kernel_size=(3,3), padding=(1,1), layout="NCHW16c",
                   kernel_layout="OIHW16i16o", out_layout="NCHW8c")
    _, ldict = correct_layout(y)
    assert ldict["data"][0] == "NCHW16c"
    assert ldict["conv_weight"][0] == "OIHW16i16o"
    assert ldict["conv_bias"][0] == "C8c"
    assert ldict["conv"][0] == "NCHW8c"

    # only the data layout is given
    y = sym.conv2d(x, name="conv", channels=12,
                   kernel_size=(3,3), padding=(1,1), layout="N16cHWC")
    _, ldict = correct_layout(y)
    assert ldict["data"][0] == "N16cHWC"
    assert ldict["conv_weight"][0] == "OIHW"
    assert ldict["conv_bias"][0] == "16cC"
    assert ldict["conv"][0] == "N16cHWC"
def test_conv2d_transpose():
    """conv2d_transpose with layout="NCHW" reports NCHW / OIHW / C layouts."""
    x = sym.Variable("data", shape=(1, 32, 512, 512))
    y = sym.conv2d_transpose(x, name="conv", channels=12,
                             kernel_size=(3,3), padding=(1,1), layout="NCHW")
    _, ldict = correct_layout(y)
    assert ldict["data"][0] == "NCHW"
    assert ldict["conv_weight"][0] == "OIHW"
    assert ldict["conv_bias"][0] == "C"
    assert ldict["conv"][0] == "NCHW"
def test_max_pool2d():
    """max_pool2d accepts NCHW16c as is but needs a transform for NHWC."""
    x = sym.Variable("data", shape=(1, 32, 512, 512))
    y = sym.max_pool2d(x, name="pool", pool_size=(3,3),
                       padding=(1,1), layout="NCHW")
    g, ldict = correct_layout(y)
    assert ldict["data"][0] == "NCHW"
    assert ldict["pool"][0] == "NCHW"
    # if index of H and W remain the same,
    # pool2d does not convert the layout.
    g, ldict = correct_layout(g, "NCHW16c")
    assert ldict["data"][0] == "NCHW16c"
    assert ldict["pool"][0] == "NCHW16c"
    # for other layout it requires a layout transform.
    g, ldict = correct_layout(g, "NHWC")
    assert ldict["data"][0] == "NHWC"
    assert ldict["data_NCHW"][0] == "NCHW"
    assert ldict["pool"][0] == "NCHW"
def test_global_pool2d():
    """Run correct_layout on global_max_pool2d three times with different input layouts.

    The graph returned by each pass is fed into the next one.
    """
    data = sym.Variable("data", shape=(1, 32, 512, 512))
    pool = sym.global_max_pool2d(data, name="pool", layout="NCHW")

    # First pass: no input layout supplied.
    graph, ldict = correct_layout(pool)
    for entry in ("data", "pool"):
        assert ldict[entry][0] == "NCHW"

    # NCHW16c keeps H and W at the same indices, so pool2d is expected
    # to take over the new layout without converting it.
    graph, ldict = correct_layout(graph, "NCHW16c")
    for entry in ("data", "pool"):
        assert ldict[entry][0] == "NCHW16c"

    # Any other layout (here NHWC) needs a layout transform back to NCHW,
    # which shows up as the extra "data_NCHW" entry.
    graph, ldict = correct_layout(graph, "NHWC")
    expected = (
        ("data", "NHWC"),
        ("data_NCHW", "NCHW"),
        ("pool", "NCHW"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout
# Level 3
def test_reshape():
    """Check that reshape yields an undefined output layout on both passes."""
    source = sym.Variable("x", shape=(4,))
    reshaped = sym.reshape(source, shape=(2,2), name="y")

    graph, ldict = correct_layout(reshaped, "C")
    for entry, layout in (("x", "C"), ("y", "__undef__")):
        assert ldict[entry][0] == layout

    # The second pass is expected to insert a layout transform,
    # visible as the additional "x_C" entry.
    graph, ldict = correct_layout(graph, "C16c")
    expected = (
        ("x", "C16c"),
        ("x_C", "C"),
        ("y", "__undef__"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout
def test_transpose():
    """Check that transpose with axes (0, 2, 3, 1) maps NCHW to NHWC."""
    source = sym.Variable("x", shape=(1, 32, 512, 512))
    transposed = sym.transpose(source, name="y", axes=(0, 2, 3, 1))

    graph, ldict = correct_layout(transposed, "NCHW")
    for entry, layout in (("x", "NCHW"), ("y", "NHWC")):
        assert ldict[entry][0] == layout

    # The second pass is expected to insert a layout transform,
    # visible as the additional "x_NCHW" entry.
    graph, ldict = correct_layout(graph, "NCHW16c")
    expected = (
        ("x", "NCHW16c"),
        ("x_NCHW", "NCHW"),
        ("y", "NHWC"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout
def test_broadcast_to():
    """Check that broadcast_to yields an undefined output layout on both passes."""
    source = sym.Variable("x", shape=(4, 1))
    expanded = sym.broadcast_to(source, shape=(0, 4), name="y")

    graph, ldict = correct_layout(expanded, "HW")
    for entry, layout in (("x", "HW"), ("y", "__undef__")):
        assert ldict[entry][0] == layout

    # The second pass is expected to insert a layout transform,
    # visible as the additional "x_HW" entry.
    graph, ldict = correct_layout(graph, "HW16h")
    expected = (
        ("x", "HW16h"),
        ("x_HW", "HW"),
        ("y", "__undef__"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout
def test_broadcast_binary():
    """Check layout resolution for broadcast_add with per-input layouts."""
    lhs = sym.Variable("x", shape=(1, 16, 512, 512))
    rhs = sym.Variable("y", shape=(16, 512, 512))
    total = sym.broadcast_add(lhs, rhs, name="z")

    # Matching layouts: nothing has to be converted.
    graph, ldict = correct_layout(total, {"x": "NCHW", "y": "CHW"})
    expected = (
        ("x", "NCHW"),
        ("y", "CHW"),
        ("z", "NCHW"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout

    # When the layouts do not match, the left operand's layout is kept
    # and the right operand is converted ("y_CHW16c").
    graph, ldict = correct_layout(graph, {"x": "NCHW16c", "y": "CHW"})
    expected = (
        ("x", "NCHW16c"),
        ("y", "CHW"),
        ("y_CHW16c", "CHW16c"),
        ("z", "NCHW16c"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout

    # broadcast_add(HCW16c, N16nCH16cW), starting again from the original
    # symbol rather than the graph produced by the previous pass.
    graph, ldict = correct_layout(total, {"x": "HCW16c", "y": "N16nCH16cW"})
    expected = (
        ("x", "HCW16c"),
        ("y", "N16nCH16cW"),
        ("x_CH16cW", "CH16cW"),
        ("z", "N16nCH16cW"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout
def test_reduce():
    """Check that sum over axis 1 yields an undefined output layout on both passes."""
    source = sym.Variable("x", shape=(1, 16, 512, 512))
    reduced = sym.sum(source, name="y", axis=1)

    graph, ldict = correct_layout(reduced, "NCHW")
    for entry, layout in (("x", "NCHW"), ("y", "__undef__")):
        assert ldict[entry][0] == layout

    # The second pass is expected to insert a layout transform,
    # visible as the additional "x_NCHW" entry.
    graph, ldict = correct_layout(graph, "NCHW16c")
    expected = (
        ("x", "NCHW16c"),
        ("x_NCHW", "NCHW"),
        ("y", "__undef__"),
    )
    for entry, layout in expected:
        assert ldict[entry][0] == layout
if __name__ == "__main__":
    # When executed as a script, run every layout test in a fixed order.
    for run_case in (
            test_dense,
            test_matmul,
            test_concatenate,
            test_expand_dims,
            test_split,
            test_batchnorm,
            test_flatten,
            test_conv2d,
            test_conv2d_transpose,
            test_max_pool2d,
            test_global_pool2d,
            test_reshape,
            test_transpose,
            test_broadcast_to,
            test_broadcast_binary,
            test_reduce,
    ):
        run_case()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment