lowered_func.h 4.58 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/tir/lowered_func.h
 * \brief Information about a lowered TVM function.
 *  This data structure is the final step toward codegen.
 */

#ifndef TVM_TIR_LOWERED_FUNC_H_
#define TVM_TIR_LOWERED_FUNC_H_
27

28 29 30
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
31 32 33
#include <string>

namespace tvm {
34
namespace tir {
35 36 37 38 39 40 41 42

// Internal node container of lowered function.
class LoweredFuncNode;

/*!
 * \brief LoweredFunc represents function after lowering.
 *  This is the final IR representation before codegen.
 */
43
class LoweredFunc : public FunctionRef {
44 45
 public:
  LoweredFunc() {}
46
  explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
47 48 49 50 51 52 53 54 55
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const LoweredFuncNode* operator->() const;
  /*! \brief specify container node */
  using ContainerType = LoweredFuncNode;
};

/*!
 * \brief Specific type of a lowered function: where its code is meant to run.
 * \note The underlying type is int and every enumerator has an explicit
 *  value; LoweredFuncNode::VisitAttrs visits a field of this type, so the
 *  numeric values presumably must stay stable — confirm before renumbering.
 */
enum LoweredFuncType : int {
  /*! \brief Function that can mix device and host calls */
  kMixedFunc = 0,
  /*! \brief Only contains host code */
  kHostFunc = 1,
  /*! \brief Only contains device code */
  kDeviceFunc = 2
};

66
/*! \brief Node container of LoweredFunc */
67
class LoweredFuncNode : public tir::FunctionBaseNode {
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
 public:
  /*! \brief The name of the function */
  std::string name;
  /*!
   * \brief The arguments of the function
   *  This function can only take pod type(int, float) and void* as arguments.
   */
  Array<Var> args;
  /*!
   * \brief The IterVar axis of threads
   *  Each axis need host function to specify a size.
   * \note Calling convention into LoweredFunc
   *
   * Assume we have a LoweredFunc f, a call into f
   *   Call(f, arg1, arg2, ..., arg_n,
   *        size_axis_1, size_axis_2, ... size_axis_m)
   *
   * Here n = len(args), m = len(thread_axis)
   *
   * The CodeGen should take this and translate this call
   * to corresponding API specific kernel launchs or function calls.
   */
  Array<IterVar> thread_axis;
  /*!
   * \brief The hint data type of Var handles defined in LetStmt
   *  Can be used as hint when generating type signiture.
   *  The creation rule is given by
   *  handle_data_type[var_handle] = make_const(the_type, 0);
   *
   * \note Expr is used instead Type, because Type cannot be hold by Map.
   *  constant Expr of given type is used.
   */
100
  Map<Var, PrimExpr> handle_data_type;
101 102
  /*! \brief The type of the function */
  LoweredFuncType func_type{kMixedFunc};
103 104
  /*! \brief Whether this function is packed function */
  bool is_packed_func{true};
105 106 107 108 109
  /*!
   * \brief Whether function ensures that argument pointers do not alias.
   *  This corresponds to restrict keyword in C.
   */
  bool is_restricted{true};
110 111 112 113 114 115 116 117 118 119 120
  /*! \brief The body statment of the function */
  Stmt body;
  /*! \return name of the operation */
  const std::string& func_name() const final {
    return name;
  }
  // there is no return value, but return 1
  // to enable Call into this function.
  int num_outputs() const final {
    return 1;
  }
121
  void VisitAttrs(AttrVisitor* v) {
122 123 124 125
    v->Visit("name", &name);
    v->Visit("args", &args);
    v->Visit("thread_axis", &thread_axis);
    v->Visit("handle_data_type", &handle_data_type);
126
    v->Visit("func_type", &func_type);
127
    v->Visit("is_packed_func", &is_packed_func);
128
    v->Visit("is_restricted", &is_restricted);
129 130 131 132
    v->Visit("body", &body);
  }

  static constexpr const char* _type_key = "LoweredFunc";
133
  TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object);
134 135 136 137
};

// Implementations of inline functions

/*!
 * \brief Access the underlying LoweredFuncNode.
 * \note This is an unchecked static_cast of get(); no runtime type check
 *  is performed, so the reference must actually hold a LoweredFuncNode.
 */
inline const LoweredFuncNode* LoweredFunc::operator->() const {
  return static_cast<const LoweredFuncNode*>(get());
}
}  // namespace tir
141 142
}  // namespace tvm

143 144
namespace std {
template <>
145
struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash {
146 147 148
};
}

149
#endif  // TVM_TIR_LOWERED_FUNC_H_