/*!
 *  Copyright (c) 2016 by Contributors
 * \file codegen_c.h
 * \brief Common utilities to generate C-style code.
 */
#ifndef TVM_CODEGEN_CODEGEN_C_H_
#define TVM_CODEGEN_CODEGEN_C_H_

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "codegen_source_base.h"

namespace tvm {
namespace codegen {

22
using namespace ir;
23 24 25 26
/*!
 * \brief A base class to generate C code.
 *
 *  CodeGenC have two modes: generate SSA formed C code or normal form.
27 28 29 30 31
 *
 * **NOTE** CodeGenC does not aim at generating C codes consumed by MSVC or GCC,
 * Rather, it's providing infrastructural abstraction for C variants like CUDA
 * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for
 * a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`.
32
 */
33 34
class CodeGenC :
      public ExprFunctor<void(const Expr&, std::ostream&)>,
35 36
      public StmtFunctor<void(const Stmt&)>,
      public CodeGenSourceBase {
37 38
 public:
  /*!
39 40
   * \brief Initialize the code generator.
   * \param output_ssa Whether output SSA.
41
   */
42 43 44 45 46 47 48 49 50 51 52
  void Init(bool output_ssa);
  /*!
   * \brief Add the function to the generated module.
   * \param f The function to be compiled.
   */
  void AddFunction(LoweredFunc f);
  /*!
   * \brief Finalize the compilation and return the code.
   * \return The code.
   */
  std::string Finish();
53 54 55 56
  /*!
   * \brief Print the Stmt n to CodeGenC->stream
   * \param n The statement to be printed.
   */
57 58 59
  void PrintStmt(const Stmt& n) {
    VisitStmt(n);
  }
60 61 62 63 64
  /*!
   * \brief Print the expression n(or its ssa id if in ssa mode) into os
   * \param n The expression to be printed.
   * \param os The output stream
   */
65
  void PrintExpr(const Expr& n, std::ostream& os);
66 67 68 69
  /*!
   * \brief Same as PrintExpr, but simply returns result string
   * \param n The expression to be printed.
   */
70
  std::string PrintExpr(const Expr& n) {
71 72 73 74
    std::ostringstream os;
    PrintExpr(n, os);
    return os.str();
  }
75
  // The following parts are overloadable print operations.
76
  /*!
77 78 79 80 81
   * \brief Insert statement before function body.
   * \param f The function to be compiled.
   */
  virtual void PreFunctionBody(LoweredFunc f) {}
  /*!
82 83 84 85
   * \brief Initialize codegen state for generating f.
   * \param f The function to be compiled.
   */
  virtual void InitFuncState(LoweredFunc f);
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
  // expression
  void VisitExpr_(const Variable* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Load* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Let* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Call* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Add* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Sub* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const NE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LT* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GT* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const And* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Or* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Cast* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const FloatImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const StringImm* op, std::ostream& os) override;  // NOLINT(*)
  // statment
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
126
  /*!
127 128
   * Print Type represetnation of type t.
   * \param t The type representation.
129
   * \param os The stream to print the ctype into
130
   */
131
  virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
132 133
  /*!
   * \brief Print expr representing the thread tag
134
   * \param IterVar iv The thread index to be binded;
135
   */
136 137
  virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
  virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
138
  virtual void PrintStorageSync(const Call* op);  // NOLINT(*)
139 140 141 142
  // Binary vector op.
  virtual void PrintVecBinaryOp(
      const std::string&op, Type op_type,
      Expr lhs, Expr rhs, std::ostream& os);  // NOLINT(*)
143
  // print vector load
144
  virtual std::string GetVecLoad(Type t, const Variable* buffer, Expr base);
145
  // print vector store
146 147 148
  virtual void PrintVecStore(const Variable* buffer,
                             Type t, Expr base,
                             const std::string& value);  // NOLINT(*)
149
  // print load of single element
150 151
  virtual void PrintVecElemLoad(
      const std::string& vec, Type t, int i, std::ostream& os);  // NOLINT(*)
152
  // print store of single element.
153 154
  virtual void PrintVecElemStore(
      const std::string& vec, Type t, int i, const std::string& value);
155 156
  // Get a cast type from to
  virtual std::string CastFromTo(std::string value, Type from, Type target);
157

158
 protected:
159 160 161
  // Print reference to struct location
  std::string GetStructRef(
      Type t, const Expr& buffer, const Expr& index, int kind);
162
  // print reference to a buffer as type t in index.
163
  virtual std::string GetBufferRef(
164
      Type t, const Variable* buffer, Expr index);
165 166 167 168 169
  /*!
   * \brief If buffer is allocated as type t.
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
170 171 172 173 174 175
  bool HandleTypeMatch(const Variable* buf_var, Type t) const;
  /*!
   * \brief Register the data type of buf_var
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
176
  void RegisterHandleType(const Variable* buf_var, Type t);
177 178 179
  // override
  void PrintSSAAssign(
      const std::string& target, const std::string& src, Type t) final;
180 181
  /*! \brief restrict keyword */
  std::string restrict_keyword_{""};
182 183
  /*! \brief the storage scope of allocation */
  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
184 185
  /*! \brief the data type of allocated buffers */
  std::unordered_map<const Variable*, Type> handle_data_type_;
186 187 188

 private:
  /*! \brief whether to print in SSA form */
189 190 191
  bool print_ssa_form_{false};
  /*! \brief set of volatile buf access */
  std::unordered_set<const Variable*> volatile_buf_;
192 193 194 195 196
};

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_CODEGEN_CODEGEN_C_H_