/*!
 *  Copyright (c) 2017 by Contributors
 * \file op_common.h
 * \brief Common operator utilities
 */
#ifndef NNVM_TOP_OP_COMMON_H_
#define NNVM_TOP_OP_COMMON_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <nnvm/top/tensor.h>

#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace nnvm {
namespace top {
/*!
 * \brief Parse keyword arguments as PType arguments and save to parsed.
 *  The string dictionary attrs->dict is fed to PType::Init; on success the
 *  resulting PType object is moved into attrs->parsed.
 * \tparam PType the parameter type.
 * \param attrs The attributes; attrs->op must be set because its name is used
 *  in the error message.
 * \throws dmlc::ParamError re-thrown with the operator name, node name and
 *  every key/value pair of attrs->dict appended to the original message.
 */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
  PType parsed_param;
  try {
    parsed_param.Init(attrs->dict);
  } catch (const dmlc::ParamError& err) {
    // Enrich the error so the user can see which node/operator failed to parse.
    std::ostringstream detail;
    detail << err.what()
           << ", in operator " << attrs->op->name << "("
           << "name=\"" << attrs->name << "\"";
    for (const auto& entry : attrs->dict) {
      detail << ", " << entry.first << "=\"" << entry.second << "\"";
    }
    detail << ")";
    throw dmlc::ParamError(detail.str());
  }
  attrs->parsed = std::move(parsed_param);
}

/*!
 * \brief Return a copy of attrs.dict that the parsed PType parameter has
 *  updated via PType::UpdateDict (presumably filling in the values of the
 *  parsed parameters — confirm against dmlc::Parameter).
 * \tparam PType the parameter type; attrs.parsed is expected to hold a PType.
 * \param attrs The attributes.
 * \return the updated copy; attrs itself is not modified.
 */
template<typename PType>
inline std::unordered_map<std::string, std::string>
ParamGetAttrDict(const nnvm::NodeAttrs& attrs) {
  const PType& param = nnvm::get<PType>(attrs.parsed);
  std::unordered_map<std::string, std::string> merged(attrs.dict);
  param.UpdateDict(&merged);
  return merged;
}

/*! \brief check if shape is empty (ndim 0) or its dims multiply to 0 (some dim unknown). */
inline bool shape_is_none(const TShape& x) {
  if (x.ndim() == 0) {
    return true;
  }
  return x.Size() == 0;
}

/*! \brief check if a type flag is none, i.e. the unknown marker -1. */
inline bool type_is_none(const int& x) {
  constexpr int kUnknownType = -1;
  return kUnknownType == x;
}

/*! \brief check if shape is the scalar shape {1}: one dim whose size is 1. */
inline bool shape_is_scalar(const TShape& x) {
  if (x.ndim() != 1) {
    return false;
  }
  return x.Size() == 1;
}

/*! \brief get string representation of shape, as produced by its operator<< */
inline std::string shape_string(const TShape& x) {
  std::ostringstream buffer;
  buffer << x;
  std::string text = buffer.str();
  return text;
}

/*! \brief get string representation of a type flag (its decimal value) */
inline std::string type_string(const int& x) {
  std::string text = std::to_string(x);
  return text;
}

/*!
 * \brief Assign x to y. Checks for compatibility when y is not empty.
 *  Missing dims (value 0) are allowed in both x and y: a 0 in y is filled from
 *  x, and a 0 in x never conflicts. Note that on a conflict y may already have
 *  had some of its earlier 0 dims filled in.
 * \param y target shape.
 * \param x source shape.
 * \return whether x and y are compatible.
 */
inline bool shape_assign(TShape *y, const TShape& x) {
  TShape& dst = *y;
  if (dst.ndim() == 0) {
    // Nothing is known about the target yet: take the source wholesale.
    dst = x;
    return true;
  }
  if (dst.ndim() != x.ndim()) {
    // A rank mismatch is only acceptable when the source carries no information.
    return x.ndim() == 0;
  }
  for (size_t axis = 0; axis < dst.ndim(); ++axis) {
    if (dst[axis] == 0) {
      dst[axis] = x[axis];
    } else if (x[axis] != 0 && x[axis] != dst[axis]) {
      return false;
    }
  }
  return true;
}

/*!
 * \brief Assign x to y. Checks for compatibility when y is not -1 (unknown).
 * \param y target type; overwritten with x only when it is -1.
 * \param x source type; -1 is compatible with anything.
 * \return whether x and y are compatible.
 */
inline bool type_assign(int *y, const int& x) {
  const int kUnknown = -1;
  if (*y != kUnknown) {
    // Target already known: the source must agree with it or be unknown.
    return x == kUnknown || x == *y;
  }
  *y = x;
  return true;
}

/*!
 * \brief Build the message reported when an inferred attribute conflicts with
 *  the value already stored for an operator input or output.
 *  The input/output is named via the operator's FListInputNames /
 *  FListOutputNames attribute when registered and when it yields a name for
 *  index; otherwise a generic "data<index>" / "output<index>" name is used.
 * \tparam AttrType attribute type; must be printable with operator<<.
 * \param attrs attributes of the node being inferred.
 * \param index position of the offending input or output.
 * \param is_input true if index refers to an input, false for an output.
 * \param expected the value inference produced.
 * \param actual the value currently stored.
 * \param attr_name attribute name used in the message (e.g. "shape", "type").
 * \return the formatted message.
 */
template<typename AttrType>
inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
                                         int index, bool is_input,
                                         const AttrType& expected,
                                         const AttrType& actual,
                                         const char* attr_name) {
  static const auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
  static const auto& flist_outputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
  const auto& flist = is_input ? flist_inputs : flist_outputs;
  std::string name = (is_input ? "data" : "output") + std::to_string(index);
  if (flist.count(attrs.op)) {
    const std::vector<std::string> names = flist[attrs.op](attrs);
    // Guard against a name list shorter than index: fall back to the generic name
    // instead of reading past the end of the vector.
    if (index >= 0 && static_cast<size_t>(index) < names.size()) {
      name = names[index];
    }
  }
  std::ostringstream msg;
  msg << "Operator " << attrs.op->name << "(";
  for (const auto& kv : attrs.dict) msg << kv.first << "=" << kv.second << ", ";
  msg << "name=" << attrs.name << ") expects " << name << "\'s " << attr_name
      << " to be " << expected << ", but got " << actual << ".";
  return msg.str();
}

/*!
 * \brief macro assign shape to input if input is unknown otherwise check consistency
 *  Use macro so we can see the error file more clearly
 * \param attrs the node attributes, used to build the error message
 * \param inputs the shape array to store the result
 * \param index the index in the array
 * \param shape the inferred shape; converted to TShape and evaluated once, so the
 *  same value is used for the assignment and the error message
 */
#define NNVM_ASSIGN_INPUT_SHAPE(attrs, inputs, index, shape)             \
  {                                                                      \
    const TShape nnvm_inferred_shape_(shape);                            \
    if (!shape_assign(&(inputs)[index], nnvm_inferred_shape_)) {         \
      LOG(FATAL) << attr_assign_error_msg(attrs, index, true,            \
                                          nnvm_inferred_shape_,          \
                                          (inputs)[index], "shape");     \
    }                                                                    \
  }

/*!
 * \brief macro assign shape to output if output is unknown otherwise check consistency
 *  Use macro so we can see the error file more clearly
 * \param attrs the node attributes, used to build the error message
 * \param outputs the shape array to store the result
 * \param index the index in the array
 * \param shape the inferred shape; converted to TShape and evaluated once, so the
 *  same value is used for the assignment and the error message
 */
#define NNVM_ASSIGN_OUTPUT_SHAPE(attrs, outputs, index, shape)           \
  {                                                                      \
    const TShape nnvm_inferred_shape_(shape);                            \
    if (!shape_assign(&(outputs)[index], nnvm_inferred_shape_)) {        \
      LOG(FATAL) << attr_assign_error_msg(attrs, index, false,           \
                                          nnvm_inferred_shape_,          \
                                          (outputs)[index], "shape");    \
    }                                                                    \
  }

/*!
 * \brief macro assign type to input if input is unknown (-1) otherwise check consistency
 *  Use macro so we can see the error file more clearly
 * \param attrs the node attributes, used to build the error message
 * \param inputs the type array to store the result
 * \param index the index in the array
 * \param type the inferred type; converted to int and evaluated once, so
 *  unscoped enum constants can be passed directly
 */
#define NNVM_ASSIGN_INPUT_TYPE(attrs, inputs, index, type)               \
  {                                                                      \
    const int nnvm_inferred_type_ = (type);                              \
    if (!type_assign(&(inputs)[index], nnvm_inferred_type_)) {           \
      LOG(FATAL) << attr_assign_error_msg(attrs, index, true,            \
                                          nnvm_inferred_type_,           \
                                          (inputs)[index], "type");      \
    }                                                                    \
  }

/*!
 * \brief macro assign type to output if output is unknown (-1) otherwise check consistency
 *  Use macro so we can see the error file more clearly
 * \param attrs the node attributes, used to build the error message
 * \param outputs the type array to store the result
 * \param index the index in the array
 * \param type the inferred type; converted to int and evaluated once, so
 *  unscoped enum constants can be passed directly
 */
#define NNVM_ASSIGN_OUTPUT_TYPE(attrs, outputs, index, type)             \
  {                                                                      \
    const int nnvm_inferred_type_ = (type);                              \
    if (!type_assign(&(outputs)[index], nnvm_inferred_type_)) {          \
      LOG(FATAL) << attr_assign_error_msg(attrs, index, false,           \
                                          nnvm_inferred_type_,           \
                                          (outputs)[index], "type");     \
    }                                                                    \
  }

/*!
 * \brief macro assign layout to outputs[index] when layout is defined;
 *  an undefined layout leaves outputs[index] untouched.
 * \param outputs the layout array to store the result
 * \param index the index in the array
 * \param layout the inferred layout; evaluated once
 */
#define NNVM_ASSIGN_LAYOUT(outputs, index, layout)                       \
  {                                                                      \
    const auto& nnvm_inferred_layout_ = (layout);                        \
    if (nnvm_inferred_layout_.defined()) {                               \
      (outputs)[index] = nnvm_inferred_layout_;                          \
    }                                                                    \
  }

/*!
 * \brief macro assign rhs shape to lhs
 *  Use macro so we can see the error file more clearly.
 *  If lhs is still unknown (ndim 0) it takes rhs; otherwise lhs and rhs must be
 *  equal. The expansion is a braced block so its internal if/else cannot bind an
 *  `else` written after the macro at the call site.
 * \param lhs lhs shape
 * \param rhs rhs shape
 */
#define SHAPE_ASSIGN(lhs, rhs)                                \
  {                                                           \
    if ((lhs).ndim() == 0) {                                  \
      (lhs) = (rhs);                                          \
    } else {                                                  \
      CHECK_EQ(lhs, rhs) << "shape inference inconsistent";   \
    }                                                         \
  }

/*!
 * \brief macro assign rhs type to lhs
 *  Use macro so we can see the error file more clearly.
 *  If lhs is still unknown (-1) it takes rhs; otherwise lhs and rhs must be
 *  equal. The expansion is a braced block so its internal if/else cannot bind an
 *  `else` written after the macro at the call site.
 * \param lhs lhs type
 * \param rhs rhs type
 */
#define DTYPE_ASSIGN(lhs, rhs)                                \
  {                                                           \
    if ((lhs) == -1) {                                        \
      (lhs) = (rhs);                                          \
    } else {                                                  \
      CHECK_EQ(lhs, rhs) << "type inference inconsistent";    \
    }                                                         \
  }

// simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs,
                      std::vector<TShape> *ishape,
                      std::vector<TShape> *oshape) {
  if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
  for (TShape& pshape : *oshape) {
    pshape = (*ishape)[0];
  }
  for (TShape& pshape : *ishape) {
    pshape = (*ishape)[0];
  }
  return true;
}

// return shape from node attrs
250
template<typename PType>
251 252 253
inline bool ZeroShape(const NodeAttrs& attrs,
                      std::vector<TShape> *ishape,
                      std::vector<TShape> *oshape) {
254
  const TShape& ts = dmlc::get<PType>(attrs.parsed).shape;
255 256 257 258 259 260 261 262
  if (ts.ndim() != 0) {
    SHAPE_ASSIGN(oshape->at(0), ts);
    return true;
  } else {
    return false;
  }
}

/*!
 * \brief Layout inference hook that infers nothing.
 *  None of the layout vectors is read or modified; it only reports success.
 * \param attrs node attributes (unused).
 * \param in_layouts input layouts (left untouched).
 * \param last_in_layouts previous input layouts (unused).
 * \param out_layouts output layouts (left untouched).
 * \return always true.
 */
inline bool ZeroLayout(const NodeAttrs& attrs,
                       std::vector<Layout> *in_layouts,
                       const std::vector<Layout> *last_in_layouts,
                       std::vector<Layout> *out_layouts) {
  // Deliberately a no-op: every layout keeps whatever value the caller passed in.
  static_cast<void>(attrs);
  static_cast<void>(in_layouts);
  static_cast<void>(last_in_layouts);
  static_cast<void>(out_layouts);
  return true;
}

// simply assign output shape or type from input
template<typename AttrType, int in_index, int out_index>
inline bool AssignOutputAttr(const NodeAttrs& attrs,
                              std::vector<AttrType> *in_attrs,
                              std::vector<AttrType> *out_attrs) {
  CHECK_LT(in_index, in_attrs->size());
  CHECK_LT(out_index, out_attrs->size());
  const TShape &dshape = in_attrs->at(in_index);
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, out_index, dshape);
  return true;
}

// return type from node attrs
284
template<typename PType>
285 286 287
inline bool ZeroType(const NodeAttrs& attrs,
                     std::vector<int> *iattr,
                     std::vector<int> *oattr) {
288
  int dtype = dmlc::get<PType>(attrs.parsed).dtype;
289 290 291 292
  DTYPE_ASSIGN(oattr->at(0), dtype);
  return true;
}

// Make zero grad node
inline std::vector<NodeEntry> MakeZeroGradNodes(
  const NodePtr& n,
  const std::vector<NodeEntry>& ograds) {
  std::vector<NodeEntry> ret;
  for (uint32_t i = 0; i < n->num_inputs(); ++i) {
    std::ostringstream os;
    ret.push_back(MakeNode("zeros_like", n->attrs.name + "_zero_grad",
                           {n->inputs[i]}));
  }
  return ret;
}

// Helper to make gradient node
inline std::vector<NodeEntry> MakeGradNode(
  const char* op_name,
  const NodePtr& n,
  std::vector<NodeEntry> inputs,
311
  std::unordered_map<std::string, std::string> attr = {{}}) {
312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
  NodePtr p = Node::Create();
  p->attrs.op = nnvm::Op::Get(op_name);
  p->attrs.name = n->attrs.name + "_grad";
  p->inputs = std::move(inputs);
  p->attrs.dict = std::move(attr);
  if (p->attrs.op->attr_parser) {
    p->attrs.op->attr_parser(&p->attrs);
  }
  std::vector<NodeEntry> ret;
  for (uint32_t i = 0; i < p->num_outputs(); ++i) {
    ret.emplace_back(NodeEntry{p, i, 0});
  }
  return ret;
}


}  // namespace top
}  // namespace nnvm

#endif  // NNVM_TOP_OP_COMMON_H_