/*! * Copyright (c) 2017 by Contributors * \file nn_common.h * \brief Common utilities for nn ops. */ #ifndef NNVM_TOP_NN_NN_COMMON_H_ #define NNVM_TOP_NN_NN_COMMON_H_ #include <dmlc/logging.h> #include <dmlc/parameter.h> #include <nnvm/layout.h> #include <nnvm/top/nn.h> #include <string> #include <vector> #include <utility> #include <algorithm> namespace nnvm { namespace top { template<typename ParamType> inline uint32_t UseBiasNumInputs(const NodeAttrs& attrs) { const ParamType& param = get<ParamType>(attrs.parsed); return param.use_bias ? 3 : 2; } template<typename ParamType> inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) { const ParamType& param = nnvm::get<ParamType>(attrs.parsed); if (param.use_bias) { return {"data", "weight", "bias"}; } else { return {"data", "weight"}; } } /*! * \brief Convert shape in src_layout to shape in dst_layout * \param src original shape * \param src_layout layout of original shape * \param dst_layout target layout * \return shape in target layout */ inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout& dst_layout) { if (src_layout == dst_layout) { return src; } else if (!src_layout.defined()) { LOG(FATAL) << "cannot convert undefined layout to " << dst_layout; } else if (!dst_layout.defined()) { LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout"; } CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << src_layout << " to " << dst_layout; TShape dst(dst_layout.ndim()); for (size_t i = 0; i < src_layout.ndim(); ++i) { Layout::LayoutDim src_dim = src_layout[i]; if (Layout::is_superdim(src_dim)) { int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_dim)); int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_dim)); int src_minor_pos = src_layout.indexof(Layout::to_subdim(src_dim)); int src_factor = src_layout.subsizeof(src_dim); int dst_factor = dst_layout.subsizeof(src_dim); uint32_t src_dim_size = src[i]; if (src_minor_pos >= 0) { CHECK_EQ(src_factor, src[src_minor_pos]) << "src shape " << src << " does not agree with layout " << src_layout; src_dim_size *= src_factor; } dst[dst_major_pos] = src_dim_size; if (dst_minor_pos >= 0) { CHECK_GT(dst_factor, 0); CHECK_LE(dst_factor, src_dim_size) << "Converting " << src << " from " << src_layout << " to " << dst_layout << ": cannot split dimension size of " << src_dim_size << " by " << dst_factor; dst[dst_major_pos] /= dst_factor; dst[dst_minor_pos] = dst_factor; } } } return dst; } } // namespace top } // namespace nnvm #endif // NNVM_TOP_NN_NN_COMMON_H_