/*! * Copyright (c) 2017 by Contributors * \file pattern_util.h * \brief Utilities for doing various pattern matching in graph. */ #ifndef NNVM_COMPILER_PATTERN_UTIL_H_ #define NNVM_COMPILER_PATTERN_UTIL_H_ #include <nnvm/graph.h> #include <vector> #include <utility> #include <string> namespace nnvm { namespace compiler { /*! * \brief find axis in oshape, such that: * bias_shape = [1,1, ... oshape[axis], 1,1,] * * This is used to detect bias or scaling factor on channel dimension. * \param oshape The output shape * \param bias_shape The shape of bias or scaling factor. * \return Pair of matched axis in o shape and bias_shape if found. */ inline std::pair<int, int> MatchBroadcast1DAxis( const TShape& oshape, const TShape& bias_shape) { dim_t axis_dim = bias_shape.ndim(); for (dim_t i = bias_shape.ndim(); i != 0; --i, --axis_dim) { if (bias_shape[i - 1] != 1) break; } // everything is 1 if (axis_dim == 0) { return {oshape.ndim() - bias_shape.ndim(), 0}; } axis_dim = axis_dim - 1; // The bias shape is not 1D for (dim_t i = 0; i < axis_dim; ++i) { if (bias_shape[i] != 1) return {-1, -1}; } int axis = static_cast<int>( oshape.ndim() - bias_shape.ndim() + axis_dim); if (oshape[axis] != bias_shape[axis_dim]) return {-1, -1}; return {axis, axis_dim}; } /*! * \brief Expand bias dimension to match needed axis. * * \param bias The bias NodeEntry * \param out_dim output dimension. * \param bias_dim The current bias dimension. * \param axis The axis we want to match on. */ inline NodeEntry ExpandBiasToMatchAxis(NodeEntry bias, int out_dim, int bias_dim, int axis) { if (bias_dim != 1) { bias = MakeNode("squeeze", bias.node->attrs.name + "_sqz", {bias}); } int num_pad_axis = out_dim - axis - 1; if (num_pad_axis > 0) { std::unordered_map<std::string, std::string> kwargs{ {"axis", "1"}, {"num_newaxis", std::to_string(num_pad_axis)}}; return MakeNode("expand_dims", bias.node->attrs.name + "_expand", {bias}, kwargs); } else { return bias; } } /*! * \brief Get the reference count of each node. * \param idx The IndexedGraph * \return ref_count vector of length number nodes. */ inline std::vector<uint32_t> GetNodeRefCounts(const IndexedGraph& idx) { std::vector<uint32_t> ref_count(idx.num_nodes(), 0); for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; for (const auto& e : inode.inputs) { ++ref_count[e.node_id]; } } for (const auto& e : idx.outputs()) { // this line will realize all the outputs ref_count[e.node_id] += 1; } return ref_count; } } // namespace compiler } // namespace nnvm #endif // NNVM_COMPILER_PATTERN_UTIL_H_