codegen_c.h 10.9 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
9
 *
10
 *   http://www.apache.org/licenses/LICENSE-2.0
11
 *
12 13 14 15 16 17 18 19
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23
/*!
 * \file codegen_c.h
 * \brief Common utilities to generate C style code.
 */
24 25
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_H_
26

27 28
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
29
#include <tvm/tir/function.h>
30
#include <tvm/tir/stmt_functor.h>
31
#include <tvm/target/codegen.h>
32
#include <tvm/tir/lowered_func.h>
33
#include <tvm/runtime/container.h>
34
#include <string>
35
#include <vector>
36
#include <unordered_map>
37
#include <unordered_set>
38
#include "codegen_source_base.h"
39 40 41 42

namespace tvm {
namespace codegen {

43
using namespace tir;
44 45 46 47
/*!
 * \brief A base class to generate C code.
 *
 *  CodeGenC has two modes: generating SSA-form C code or normal-form C code.
48 49 50 51 52
 *
 * **NOTE** CodeGenC does not aim at generating C code consumed by MSVC or GCC.
 * Rather, it provides an infrastructural abstraction for C variants like CUDA
 * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for
 * a vector of 3 `int`s. For a native C code generator, see `CodeGenLLVM`.
53
 */
54
class CodeGenC :
55
      public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
56 57
      public StmtFunctor<void(const Stmt&)>,
      public CodeGenSourceBase {
58 59
 public:
  /*!
60 61
   * \brief Initialize the code generator.
   * \param output_ssa Whether output SSA.
62
   */
63 64 65 66
  void Init(bool output_ssa);
  /*!
   * \brief Add the function to the generated module.
   * \param f The function to be compiled.
67
   * \param whether to append return 0 in the end.
68
   */
69
  void AddFunction(const PrimFunc& f);
70 71 72 73 74
  /*!
   * \brief Finalize the compilation and return the code.
   * \return The code.
   */
  std::string Finish();
75 76 77 78
  /*!
   * \brief Print the Stmt n to CodeGenC->stream
   * \param n The statement to be printed.
   */
79 80 81
  void PrintStmt(const Stmt& n) {
    VisitStmt(n);
  }
82 83 84 85 86
  /*!
   * \brief Print the expression n (or its SSA id if in SSA mode) into os
   * \param n The expression to be printed.
   * \param os The output stream
   */
87
  void PrintExpr(const PrimExpr& n, std::ostream& os);
88 89 90 91
  /*!
   * \brief Same as PrintExpr, but simply returns result string
   * \param n The expression to be printed.
   */
92
  std::string PrintExpr(const PrimExpr& n) {
93 94 95 96
    std::ostringstream os;
    PrintExpr(n, os);
    return os.str();
  }
97
  // The following parts are overloadable print operations.
98
  /*!
99 100 101 102 103 104 105 106 107 108
   * \brief Print the function header before the argument list
   *
   *  Example: stream << "void";
   */
  virtual void PrintFuncPrefix(); // NOLINT(*)
  /*!
   * \brief Print the final return at the end of the function.
   */
  virtual void PrintFinalReturn(); // NOLINT(*)
  /*!
109 110 111
   * \brief Insert statement before function body.
   * \param f The function to be compiled.
   */
112
  virtual void PreFunctionBody(const PrimFunc& f) {
    // Intentionally empty: the base class inserts nothing before the body.
  }
113
  /*!
114 115 116
   * \brief Initialize codegen state for generating f.
   * \param f The function to be compiled.
   */
117
  virtual void InitFuncState(const PrimFunc& f);
118
  // expression
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
  // Per-node overrides of the ExprFunctor<void(const PrimExpr&, std::ostream&)>
  // dispatch inherited by this class. Each overload receives the concrete
  // expression node together with the output stream from that signature.
  void VisitExpr_(const VarNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LoadNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LetNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const CallNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const AddNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const SubNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const MulNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const DivNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const ModNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const MinNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const MaxNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const EQNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const NENode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LTNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LENode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GTNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GENode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const AndNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const OrNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const CastNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const NotNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const SelectNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const RampNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const ShuffleNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const BroadcastNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const IntImmNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const FloatImmNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const StringImmNode* op, std::ostream& os) override;  // NOLINT(*)
147
  // statement
148 149 150 151 152 153 154 155
  // Per-node overrides of the StmtFunctor<void(const Stmt&)> dispatch
  // inherited by this class. Unlike the expression visitors, these take no
  // stream argument and return nothing.
  void VisitStmt_(const LetStmtNode* op) override;
  void VisitStmt_(const StoreNode* op) override;
  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const AllocateNode* op) override;
  void VisitStmt_(const AttrStmtNode* op) override;
  void VisitStmt_(const AssertStmtNode* op) override;
  void VisitStmt_(const EvaluateNode* op) override;
156
  void VisitStmt_(const SeqStmtNode* op) override;
157
  void VisitStmt_(const ProducerConsumerNode* op) override;
158
  /*!
159 160
   * Print Type representation of type t.
   * \param t The type representation.
161
   * \param os The stream to print the ctype into
162
   */
163
  virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
164
  /*!
165 166 167 168 169 170
   * Print Type representation of type type.
   * \param type The type representation.
   * \param os The stream to print the ctype into
   */
  virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*)
  /*!
171
   * \brief Print expr representing the thread tag
172
   * \param iv The thread index to be bound.
173
   */
174 175
  virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
  virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
176
  virtual void PrintStorageSync(const CallNode* op);  // NOLINT(*)
177 178
  // Binary vector op.
  virtual void PrintVecBinaryOp(
179
      const std::string&op, DataType op_type,
180
      PrimExpr lhs, PrimExpr rhs, std::ostream& os);  // NOLINT(*)
181
  // print vector load
182
  virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base);
183
  // print vector store
184
  virtual void PrintVecStore(const VarNode* buffer,
185
                             DataType t, PrimExpr base,
186
                             const std::string& value);  // NOLINT(*)
187
  // print load of single element
188
  virtual void PrintVecElemLoad(
189
      const std::string& vec, DataType t, int i, std::ostream& os);  // NOLINT(*)
190
  // print store of single element.
191
  virtual void PrintVecElemStore(
192
      const std::string& vec, DataType t, int i, const std::string& value);
193
  // Get a cast type from to
194
  virtual std::string CastFromTo(std::string value, DataType from, DataType target);
195

196
 protected:
197 198
  // Print reference to struct location
  std::string GetStructRef(
199
      DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
200
  // Print reference to a buffer as type t in index.
201
  virtual std::string GetBufferRef(
202
      DataType t, const VarNode* buffer, PrimExpr index);
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229

  /*!
   * \brief Hook for emitting a loaded value, so targets can special-case volatile loads.
   *
   * It exists to work around a bug in CUDA's cuda_fp16.h: reductions need
   * volatile accesses to shared memory, but the __half class implements no
   * volatile member functions, so the CUDA codegen will cast the volatile
   * qualifier away from CUDA __half types.
   *
   * \param value The text of the loaded value.
   * \param op The load node; the default implementation does not inspect it.
   * \param os The stream the value is written to.
   */
  virtual void HandleVolatileLoads(const std::string& value,
                                   const LoadNode* op,
                                   std::ostream& os) {
    // Base behaviour: emit the loaded value unchanged.
    os << value;
  }

  /*!
   * \brief Tell whether the target language treats the storage scope as part of the type.
   *
   * **NOTE** OpenCL makes __local part of the type, which is why
   * "__local int *" is legal there. CUDA differs: "__shared__" and
   * "__constant__" are storage classes (comparable to C/C++ static),
   * not part of the type.
   *
   * \return true in this base implementation.
   */
  virtual bool IsScopePartOfType() const { return true; }

230 231 232 233 234
  /*!
   * \brief If buffer is allocated as type t.
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
235
  bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
236 237 238 239 240
  /*!
   * \brief Register the data type of buf_var
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
241
  void RegisterHandleType(const VarNode* buf_var, DataType t);
242 243
  // override
  void PrintSSAAssign(
244
      const std::string& target, const std::string& src, DataType t) final;
245 246
  /*! \brief reserves common C keywords */
  void ReserveKeywordsAsUnique();
247

248 249 250 251 252
  /*! \brief Tell whether buf_var has been recorded in the volatile buffer set. */
  bool IsVolatile(const VarNode* buf_var) const {
    return volatile_buf_.find(buf_var) != volatile_buf_.end();
  }

253 254 255 256 257 258 259
  /*! \brief restrict keyword */
  std::string restrict_keyword_{""};
  /*! \brief the storage scope of allocation */
  std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
  /*! \brief the data type of allocated buffers */
  std::unordered_map<const VarNode*, DataType> handle_data_type_;

260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
  /*!
   * \brief A RAII utility class for emitting code in a scoped region.
   */
  class EnterScopeRAII {
    // The codegen context.
    CodeGenC* cg;

    // The new scope level.
    int scope;

   public:
    explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) {
      cg->PrintIndent();
      cg->stream << "{\n";
      scope = cg->BeginScope();
    }
    ~EnterScopeRAII() {
      cg->EndScope(scope);
      cg->PrintIndent();
      cg->stream << "}\n";
    }
  };

283 284
 private:
  /*! \brief whether to print in SSA form */
285 286
  bool print_ssa_form_{false};
  /*! \brief set of volatile buf access */
287
  std::unordered_set<const VarNode*> volatile_buf_;
288 289 290 291
};

}  // namespace codegen
}  // namespace tvm
292
#endif  // TVM_TARGET_SOURCE_CODEGEN_C_H_