/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/lowered_func.h
 * \brief Information about a lowered TVM function.
 *  This data structure is the final step toward codegen.
 */
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_
27 28 29

#include <string>

30 31 32
#include "base.h"
#include "expr.h"
#include "tensor.h"
33
#include "tvm/node/container.h"
34 35 36 37 38 39 40 41 42 43

namespace tvm {

// Internal node container of lowered function.
class LoweredFuncNode;

/*!
 * \brief LoweredFunc represents function after lowering.
 *  This is the final IR representation before codegen.
 */
44
class LoweredFunc : public ir::FunctionRef {
45 46
 public:
  LoweredFunc() {}
47
  explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
48 49 50 51 52 53 54 55 56
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const LoweredFuncNode* operator->() const;
  /*! \brief specify container node */
  using ContainerType = LoweredFuncNode;
};

57 58 59 60 61 62 63 64 65 66
/*!
 * \brief Kind of a lowered function, classified by which code it contains.
 */
enum LoweredFuncType : int {
  kMixedFunc = 0,   /*!< \brief Function that can mix device and host calls */
  kHostFunc = 1,    /*!< \brief Only contains host code */
  kDeviceFunc = 2,  /*!< \brief Only contains device code */
};

67
/*! \brief Node container of LoweredFunc */
68
class LoweredFuncNode : public ir::FunctionBaseNode {
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
 public:
  /*! \brief The name of the function */
  std::string name;
  /*!
   * \brief The arguments of the function
   *  This function can only take pod type(int, float) and void* as arguments.
   */
  Array<Var> args;
  /*!
   * \brief The IterVar axis of threads
   *  Each axis need host function to specify a size.
   * \note Calling convention into LoweredFunc
   *
   * Assume we have a LoweredFunc f, a call into f
   *   Call(f, arg1, arg2, ..., arg_n,
   *        size_axis_1, size_axis_2, ... size_axis_m)
   *
   * Here n = len(args), m = len(thread_axis)
   *
   * The CodeGen should take this and translate this call
   * to corresponding API specific kernel launchs or function calls.
   */
  Array<IterVar> thread_axis;
  /*!
   * \brief The hint data type of Var handles defined in LetStmt
   *  Can be used as hint when generating type signiture.
   *  The creation rule is given by
   *  handle_data_type[var_handle] = make_const(the_type, 0);
   *
   * \note Expr is used instead Type, because Type cannot be hold by Map.
   *  constant Expr of given type is used.
   */
  Map<Var, Expr> handle_data_type;
102 103
  /*! \brief The type of the function */
  LoweredFuncType func_type{kMixedFunc};
104 105
  /*! \brief Whether this function is packed function */
  bool is_packed_func{true};
106 107 108 109 110
  /*!
   * \brief Whether function ensures that argument pointers do not alias.
   *  This corresponds to restrict keyword in C.
   */
  bool is_restricted{true};
111 112 113 114 115 116 117 118 119 120 121
  /*! \brief The body statment of the function */
  Stmt body;
  /*! \return name of the operation */
  const std::string& func_name() const final {
    return name;
  }
  // there is no return value, but return 1
  // to enable Call into this function.
  int num_outputs() const final {
    return 1;
  }
122
  void VisitAttrs(AttrVisitor* v) {
123 124 125 126
    v->Visit("name", &name);
    v->Visit("args", &args);
    v->Visit("thread_axis", &thread_axis);
    v->Visit("handle_data_type", &handle_data_type);
127
    v->Visit("func_type", &func_type);
128
    v->Visit("is_packed_func", &is_packed_func);
129
    v->Visit("is_restricted", &is_restricted);
130 131 132 133
    v->Visit("body", &body);
  }

  static constexpr const char* _type_key = "LoweredFunc";
134
  TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
135 136 137 138
};

// Implementations of inline functions
/*!
 * \brief Access the underlying node as a LoweredFuncNode.
 * \return The pointer returned by get(), downcast to const LoweredFuncNode*.
 * \note The downcast is an unchecked static_cast; it relies on the reference
 *  actually holding a LoweredFuncNode (or being null).
 */
inline const LoweredFuncNode* LoweredFunc::operator->() const {
  return static_cast<const LoweredFuncNode*>(get());
}

}  // namespace tvm

144 145
namespace std {
template <>
146
struct hash<::tvm::LoweredFunc> : public tvm::NodeHash {
147 148 149 150
};
}

#endif  // TVM_LOWERED_FUNC_H_