/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_stackvm.h
 * \brief Codegen into Simple Stack VM.
 */
#ifndef TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#define TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
8 9

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h>
#include <tvm/codegen.h>
#include <cstdint>
#include <string>
#include <vector>
#include <unordered_map>

#include "../../runtime/stackvm/stackvm.h"

namespace tvm {
namespace codegen {

22
using namespace ir;
23 24
using runtime::StackVM;

/*!
 * \brief A base class to generate a stack VM.
 *  This module is used to generate the host wrapper that
 *  calls into device functions when only device JIT is available.
 */
class CodeGenStackVM
    : public ExprFunctor<void(const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
 public:
  /*!
   * \brief Generate a stack VM program for a lowered function.
   * \param f The function to be compiled
   * \return The generated stack VM.
   * \note Only call compile once,
   *  create a new codegen object each time.
   */
  StackVM Compile(LoweredFunc f);
  /*! \brief Push stmt to generate new code */
  void Push(const Stmt& n);
  /*! \brief Push expr to generate new code */
  void Push(const Expr& n) {
    // Dispatch through the expression functor to the matching VisitExpr_.
    VisitExpr(n);
  }
  /*!
   * \brief Push the opcode to the code.
   * \param opcode The code to be pushed.
   */
  void PushOp(StackVM::OpCode opcode);
  /*!
   * \brief Push the opcode and operand to the code.
   * \param opcode The opcode.
   * \param operand The operand to be pushed.
   * \return operand_index, indicating location of operand
   */
  int64_t PushOp(StackVM::OpCode opcode, int operand);
  /*!
   * \brief Set the operand at a previously pushed location,
   *  e.g. to fill in a relative jump offset.
   * \param operand_index The index returned by PushOp.
   * \param operand The operand to be set.
   */
  void SetOperand(int64_t operand_index, int64_t operand);
  /*! \return The current program pointer */
  int64_t GetPC() const {
    // The program pointer is the number of code entries emitted so far.
    return static_cast<int64_t>(vm_.code.size());
  }
  /*!
   * \brief Get string id in vm
   * \param key The string to get id.
   * \return the id of the string.
   */
  int GetStrID(const std::string& key);
  /*!
   * \brief Allocate a variable name for a newly defined var.
   * \param v The variable.
   * \return the heap index of the var.
   */
  int AllocVarID(const Variable* v);
  /*!
   * \brief Get a variable name.
   * \param v The variable.
   * \return the heap index of the var.
   */
  int GetVarID(const Variable* v) const;
  /*!
   * \brief Push a binary operator applied to two operands.
   * \param op_int64 The opcode to use (presumably the int64 variant,
   *  going by the parameter name -- confirm against the implementation).
   * \param a The left operand.
   * \param b The right operand.
   */
  void PushBinary(StackVM::OpCode op_int64,
                  const Expr& a,
                  const Expr& b);
  /*!
   * \brief Push a cast.
   * \param dst The destination type.
   * \param src The source type.
   */
  void PushCast(Type dst, Type src);
  // overloadable functions
  // expression
  void VisitExpr_(const Variable* op) final;
  void VisitExpr_(const Load* op) final;
  void VisitExpr_(const Let* op) final;
  void VisitExpr_(const Call* op) final;
  void VisitExpr_(const Add* op) final;
  void VisitExpr_(const Sub* op) final;
  void VisitExpr_(const Mul* op) final;
  void VisitExpr_(const Div* op) final;
  void VisitExpr_(const Mod* op) final;
  void VisitExpr_(const Min* op) final;
  void VisitExpr_(const Max* op) final;
  void VisitExpr_(const EQ* op) final;
  void VisitExpr_(const NE* op) final;
  void VisitExpr_(const LT* op) final;
  void VisitExpr_(const LE* op) final;
  void VisitExpr_(const GT* op) final;
  void VisitExpr_(const GE* op) final;
  void VisitExpr_(const And* op) final;
  void VisitExpr_(const Or* op) final;
  void VisitExpr_(const Cast* op) final;
  void VisitExpr_(const Not* op) final;
  void VisitExpr_(const Select* op) final;
  void VisitExpr_(const Ramp* op) final;
  void VisitExpr_(const Broadcast* op) final;
  void VisitExpr_(const IntImm* op) final;
  void VisitExpr_(const UIntImm* op) final;
  void VisitExpr_(const FloatImm* op) final;
  void VisitExpr_(const StringImm* op) final;
  // statement
  void VisitStmt_(const LetStmt* op) final;
  void VisitStmt_(const Store* op) final;
  void VisitStmt_(const For* op) final;
  void VisitStmt_(const IfThenElse* op) final;
  void VisitStmt_(const Allocate* op) final;
  void VisitStmt_(const AttrStmt* op) final;
  void VisitStmt_(const AssertStmt* op) final;
  void VisitStmt_(const Evaluate* op) final;
  void VisitStmt_(const Block* op) final;
  void VisitStmt_(const ProducerConsumer* op) final;

 private:
  /*! \brief Debug flag, off by default (its use is not visible in this header) */
  bool debug_{false};
  /*! \brief The vm to be generated */
  StackVM vm_;
  /*! \brief id of each variable */
  std::unordered_map<const Variable*, int> var_idmap_;
  /*! \brief id of each string */
  std::unordered_map<std::string, int> str_idmap_;
  /*! \brief id of each global function */
  std::unordered_map<std::string, int> extern_fun_idmap_;
};

}  // namespace codegen
}  // namespace tvm
150
#endif  // TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_