op_common.h 2.68 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
/*!
 *  Copyright (c) 2018 by Contributors
 * \file op_common.h
 * \brief A set of utilities and common functionality
 * for relay ops.
 */
#ifndef TVM_RELAY_OP_OP_COMMON_H_
#define TVM_RELAY_OP_OP_COMMON_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <vector>

namespace tvm {
namespace relay {

/*!
 * \brief Copy every element of a TVM Array into a std::vector.
 * \tparam T Element type stored in the array.
 * \param array Source array; it is read, not modified.
 * \return A std::vector whose elements are those of \p array, in order.
 */
template<typename T>
std::vector<T> AsVector(const Array<T> &array) {
  const size_t count = array.size();
  std::vector<T> out;
  // Allocate once up front; the final size is known.
  out.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    out.push_back(array[i]);
  }
  return out;
}

雾雨魔理沙 committed
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
/*! \brief Register a unary relay operator together with its make function.
 *
 * Expanding this macro does two things:
 * - registers an API function named `Prefix OpName` that takes one Expr and
 *   returns CallNode::make(op, {data}, Attrs(), {}) for the operator
 *   `OpName` (default Attrs and an empty final argument list);
 * - registers `OpName` in the op registry with a single input, "data".
 *
 * Only positional arguments are exposed. Conveniences such as keyword
 * arguments and default values are expected to be added by frontend
 * wrappers.
 *
 * No type relation is attached here. The expansion ends without a
 * semicolon, so the use site can keep chaining registry calls
 * (for example `.add_type_rel(...)`).
 *
 * \param Prefix String-literal prefix of the make-function name, for
 *        example "relay.op._make.".
 * \param OpName String literal naming the operator. It is juxtaposed with
 *        Prefix, so both must be string literals.
 */
#define RELAY_REGISTER_UNARY_OP(Prefix, OpName)                   \
  TVM_REGISTER_API(Prefix OpName)                                 \
    .set_body_typed<Expr(Expr)>([](Expr operand) {                \
        static const Op& unary_op = Op::Get(OpName);              \
        return CallNode::make(unary_op, {operand}, Attrs(), {});  \
      });                                                         \
  RELAY_REGISTER_OP(OpName)                                       \
    .set_num_inputs(1)                                            \
    .add_argument("data", "Tensor", "The input tensor.")

/*! \brief Register a broadcasting binary relay operator and its make function.
 *
 * Expanding this macro does three things:
 * - registers an API function named `Prefix OpName` that takes two Exprs
 *   and returns CallNode::make(op, {lhs, rhs}, Attrs(), {}) for the
 *   operator `OpName` (default Attrs and an empty final argument list);
 * - registers `OpName` in the op registry with two inputs, "lhs" and "rhs";
 * - attaches the "Broadcast" type relation, implemented by BroadcastRel.
 *
 * Only positional arguments are exposed. Conveniences such as keyword
 * arguments and default values are expected to be added by frontend
 * wrappers.
 *
 * This file does not declare BroadcastRel, so it must be visible wherever
 * the macro is expanded. The expansion ends without a semicolon, so the use
 * site can keep chaining registry calls.
 *
 * \param Prefix String-literal prefix of the make-function name, for
 *        example "relay.op._make.".
 * \param OpName String literal naming the operator. It is juxtaposed with
 *        Prefix, so both must be string literals.
 */
#define RELAY_REGISTER_BINARY_OP(Prefix, OpName)                       \
  TVM_REGISTER_API(Prefix OpName)                                      \
    .set_body_typed<Expr(Expr, Expr)>([](Expr left, Expr right) {      \
        static const Op& binary_op = Op::Get(OpName);                  \
        return CallNode::make(binary_op, {left, right}, Attrs(), {});  \
      });                                                              \
  RELAY_REGISTER_OP(OpName)                                            \
    .set_num_inputs(2)                                                 \
    .add_argument("lhs", "Tensor", "The left hand side tensor.")       \
    .add_argument("rhs", "Tensor", "The right hand side tensor.")      \
    .add_type_rel("Broadcast", BroadcastRel)

73 74 75 76
}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_OP_OP_COMMON_H_