/*!
 *  Copyright (c) 2017 by Contributors
 * \file tvm/lowered_func.h
 * \brief Information about a lowered TVM function.
 *  This data structure is the final step toward codegen.
 */
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_

#include <ir/FunctionBase.h>
#include <string>

#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "tvm/node/container.h"

namespace tvm {

// Internal node container of lowered function.
class LoweredFuncNode;

/*!
 * \brief LoweredFunc represents function after lowering.
 *  This is the final IR representation before codegen.
 */
class LoweredFunc : public FunctionRef {
 public:
  LoweredFunc() {}
30
  explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
31 32 33 34 35 36 37 38 39
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const LoweredFuncNode* operator->() const;
  /*! \brief specify container node */
  using ContainerType = LoweredFuncNode;
};

/*!
 * \brief Which side(s) a lowered function's code is written for.
 *  The underlying type is fixed to int, and the enumerators keep the
 *  stable values 0, 1 and 2.
 */
enum LoweredFuncType : int {
  kMixedFunc = 0,  /*!< may mix device and host calls (value 0) */
  kHostFunc,       /*!< contains host code only (value 1) */
  kDeviceFunc      /*!< contains device code only (value 2) */
};

/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
 public:
  /*! \brief The name of the function */
  std::string name;
  /*!
   * \brief The arguments of the function
   *  This function can only take pod type(int, float) and void* as arguments.
   */
  Array<Var> args;
  /*!
   * \brief The IterVar axis of threads
   *  Each axis need host function to specify a size.
   * \note Calling convention into LoweredFunc
   *
   * Assume we have a LoweredFunc f, a call into f
   *   Call(f, arg1, arg2, ..., arg_n,
   *        size_axis_1, size_axis_2, ... size_axis_m)
   *
   * Here n = len(args), m = len(thread_axis)
   *
   * The CodeGen should take this and translate this call
   * to corresponding API specific kernel launchs or function calls.
   */
  Array<IterVar> thread_axis;
  /*!
   * \brief The hint data type of Var handles defined in LetStmt
   *  Can be used as hint when generating type signiture.
   *  The creation rule is given by
   *  handle_data_type[var_handle] = make_const(the_type, 0);
   *
   * \note Expr is used instead Type, because Type cannot be hold by Map.
   *  constant Expr of given type is used.
   */
  Map<Var, Expr> handle_data_type;
85 86
  /*! \brief The type of the function */
  LoweredFuncType func_type{kMixedFunc};
87 88
  /*! \brief Whether this function is packed function */
  bool is_packed_func{true};
89 90 91 92 93
  /*!
   * \brief Whether function ensures that argument pointers do not alias.
   *  This corresponds to restrict keyword in C.
   */
  bool is_restricted{true};
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
  /*! \brief The body statment of the function */
  Stmt body;
  /*! \return name of the operation */
  const std::string& func_name() const final {
    return name;
  }
  // there is no return value, but return 1
  // to enable Call into this function.
  int num_outputs() const final {
    return 1;
  }
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("args", &args);
    v->Visit("thread_axis", &thread_axis);
    v->Visit("handle_data_type", &handle_data_type);
110
    v->Visit("func_type", &func_type);
111
    v->Visit("is_packed_func", &is_packed_func);
112
    v->Visit("is_restricted", &is_restricted);
113 114 115 116
    v->Visit("body", &body);
  }

  static constexpr const char* _type_key = "LoweredFunc";
117
  TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
118 119 120 121 122 123 124 125 126
};

// Out-of-line definition of LoweredFunc::operator->. It is placed after
// LoweredFuncNode is complete because a static_cast downcast needs the
// complete derived type.
inline const LoweredFuncNode* LoweredFunc::operator->() const {
  const auto* raw_node = node_.get();
  return static_cast<const LoweredFuncNode*>(raw_node);
}

}  // namespace tvm

namespace std {
/*!
 * \brief std::hash specialization for LoweredFunc.
 *  Lets LoweredFunc act as a key in hashed standard containers; the
 *  value is whatever the reference's own hash() returns.
 */
template <>
struct hash<::tvm::LoweredFunc> {
  std::size_t operator()(const ::tvm::LoweredFunc& func) const {
    const std::size_t digest = func.hash();
    return digest;
  }
};
}  // namespace std

#endif  // TVM_LOWERED_FUNC_H_