/*!
 *  Copyright (c) 2016 by Contributors
 * \file codegen_c.h
 * \brief Common utilities to generate C style code.
 */
#ifndef TVM_CODEGEN_CODEGEN_C_H_
#define TVM_CODEGEN_CODEGEN_C_H_

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace codegen {

20
using namespace ir;
21 22 23 24 25
/*!
 * \brief A base class to generate C code.
 *
 *  CodeGenC have two modes: generate SSA formed C code or normal form.
 */
26 27 28
class CodeGenC :
      public ExprFunctor<void(const Expr&, std::ostream&)>,
      public StmtFunctor<void(const Stmt&)> {
29 30
 public:
  /*!
31 32
   * \brief Initialize the code generator.
   * \param output_ssa Whether output SSA.
33
   */
34 35 36 37 38 39 40 41 42 43 44
  void Init(bool output_ssa);
  /*!
   * \brief Add the function to the generated module.
   * \param f The function to be compiled.
   */
  void AddFunction(LoweredFunc f);
  /*!
   * \brief Finalize the compilation and return the code.
   * \return The code.
   */
  std::string Finish();
45 46 47 48
  /*!
   * \brief Print the Stmt n to CodeGenC->stream
   * \param n The statement to be printed.
   */
49 50 51
  void PrintStmt(const Stmt& n) {
    VisitStmt(n);
  }
52 53 54 55 56
  /*!
   * \brief Print the expression n(or its ssa id if in ssa mode) into os
   * \param n The expression to be printed.
   * \param os The output stream
   */
57
  void PrintExpr(const Expr& n, std::ostream& os);
58 59 60 61
  /*!
   * \brief Same as PrintExpr, but simply returns result string
   * \param n The expression to be printed.
   */
62
  std::string PrintExpr(const Expr& n) {
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
    std::ostringstream os;
    PrintExpr(n, os);
    return os.str();
  }
  /*! \brief print the current indented value */
  void PrintIndent();
  /*!
   * \brief Register constant value appeared in expresion tree
   *  This avoid generated a ssa id for each appearance of the value
   * \param value The constant value.
   */
  void MarkConst(std::string value);
  /*!
   * \brief Allocate a variable name for a newly defined var.
   * \param v The variable.
   * \return the variable name.
   */
  std::string AllocVarID(const Variable* v);
  /*!
   * \brief Get a variable name.
   * \param v The variable.
   * \return the variable name.
   */
  std::string GetVarID(const Variable* v) const;
87
  // The following parts are overloadable print operations.
88
  /*!
89 90 91 92
   * \brief Initialize codegen state for generating f.
   * \param f The function to be compiled.
   */
  virtual void InitFuncState(LoweredFunc f);
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
  // expression
  void VisitExpr_(const Variable* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Load* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Let* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Call* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Add* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Sub* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const NE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LT* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GT* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const And* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Or* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Cast* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const FloatImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const StringImm* op, std::ostream& os) override;  // NOLINT(*)
  // statment
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
133
  /*!
134 135
   * Print Type represetnation of type t.
   * \param t The type representation.
136
   * \param os The stream to print the ctype into
137 138
   */
  virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*)
139 140
  /*!
   * \brief Print expr representing the thread tag
141
   * \param tag The tag in the thread.
142 143
   * \param  os The strean to output to
   */
144 145
  virtual void PrintThreadIndexExpr(
      std::string tag, std::ostream& os); // NOLINT(*)
146
  virtual void PrintStorageScope(const std::string& scope, std::ostream& os);  // NOLINT(*)
147
  virtual void PrintStorageSync(const std::string& scope);  // NOLINT(*)
148 149 150 151
  // Binary vector op.
  virtual void PrintVecBinaryOp(
      const std::string&op, Type op_type,
      Expr lhs, Expr rhs, std::ostream& os);  // NOLINT(*)
152
  // print vector load
153 154 155
  virtual void PrintVecLoad(const Variable* buffer,
                            Type t, Expr base,
                            std::ostream& os);  // NOLINT(*)
156
  // print vector store
157 158 159
  virtual void PrintVecStore(const Variable* buffer,
                             Type t, Expr base,
                             const std::string& value);  // NOLINT(*)
160
  // print load of single element
161 162
  virtual void PrintVecElemLoad(
      const std::string& vec, Type t, int i, std::ostream& os);  // NOLINT(*)
163
  // print store of single element.
164 165
  virtual void PrintVecElemStore(
      const std::string& vec, Type t, int i, const std::string& value);
166

167
 protected:
168 169 170 171 172 173 174 175 176
  /*! \brief the stream to be printed */
  std::ostringstream stream;
  /*! \brief entry in ssa assign map */
  struct SSAEntry {
    /*! \brief The value id */
    std::string vid;
    /*! \brief The scope id */
    int scope_id;
  };
177 178 179 180
  // print reference to a buffer as type t in index.
  void PrintBufferRef(const Variable* buffer,
                      Type t, Expr index,
                      std::ostream& os);  // NOLINT(*)
181 182 183 184 185 186 187 188
  /*!
   * \brief Get the SSA ID corresponds to src
   *  If necessary, generate new assignment
   * \param src The source expression
   * \param t The type of the expression.
   */
  std::string SSAGetID(std::string src, Type t);
  /*!
189 190 191 192 193 194
   * \brief get a unique name with the corresponding prefix
   * \param prefix The prefix of the name
   * \return The returned name.
   */
  std::string GetUniqueName(std::string prefix);
  /*!
195 196 197 198 199 200 201 202 203 204
   * \brief mark the beginning of a new scope
   * \return The scope id.
   */
  int BeginScope();
  /*!
   * \brief mark the end of an old scope.
   * \param scope_id The scope id to be ended.
   */
  void EndScope(int scope_id);
  /*!
205 206 207 208
   * \brief If buffer is allocated as type t.
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
209 210 211 212 213 214
  bool HandleTypeMatch(const Variable* buf_var, Type t) const;
  /*!
   * \brief Register the data type of buf_var
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
215
  void RegisterHandleType(const Variable* buf_var, Type t);
216
  /*!
217 218 219
   * \brief Get the storage scope of buf_var.
   * \param buf_var The buf_var to be queryed.
   * \return The storage scope.
220
   */
221
  std::string GetStorageScope(const Variable* buf_var) const;
222 223
  /*! \brief the storage scope of allocation */
  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
224 225 226 227

 private:
  /*! \brief whether to print in SSA form */
  bool print_ssa_form_{true};
228 229 230
  /*! \brief name allocation map */
  std::unordered_map<std::string, int> name_alloc_map_;
  /*! \brief assignment map of ssa */
231
  std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
232 233 234 235
  /*! \brief name of each variable */
  std::unordered_map<const Variable*, std::string> var_idmap_;
  /*! \brief the data type of allocated buffers */
  std::unordered_map<const Variable*, Type> handle_data_type_;
236 237
  /*! \brief array to check whether we are inside certain scope */
  std::vector<bool> scope_mark_;
238 239
  /*! \brief The current indentation value */
  int indent{0};
240 241 242 243 244
};

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_CODEGEN_CODEGEN_C_H_