codegen_cuda.h 1.88 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_cuda.h
 * \brief Utility to generate cuda code
 */
#ifndef TVM_CODEGEN_CODEGEN_CUDA_H_
#define TVM_CODEGEN_CODEGEN_CUDA_H_

#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
12
#include "codegen_c.h"
13 14 15 16

namespace tvm {
namespace codegen {

17
class CodeGenCUDA final : public CodeGenC {
18
 public:
19
  CodeGenCUDA();
20
  void Init(bool output_ssa);
21
  void AddFunction(LoweredFunc f);
22
  std::string Finish();
23
  bool need_include_path() { return (enable_fp16_ || enable_int8_); }
24
  // override behavior
25
  void VisitStmt_(const ir::For* op) final;
26
  void PrintStorageSync(const Call* op) final;
27
  void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
28 29 30
  void PrintVecBinaryOp(
      const std::string&op, Type t,
      Expr lhs, Expr rhs, std::ostream& os) final;  // NOLINT(*)
31
  void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
32 33 34 35
  void PrintVecElemLoad(
      const std::string& vec, Type t, int i, std::ostream& os) final;  // NOLINT(*)
  void PrintVecElemStore(
      const std::string& vec, Type t, int i, const std::string& value) final;
36
  void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
37
  // overload visitor
38
  void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
39
  void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
40
  void VisitExpr_(const FloatImm *op, std::ostream& os) final;
41 42
  void VisitStmt_(const Evaluate *op) final;

43
 private:
44 45 46 47 48 49
  // Whether global barrier is needed.
  bool need_global_barrier_{false};
  // Global barrier state
  std::string vid_global_barrier_state_;
  // Global barrier expected node.
  std::string vid_global_barrier_expect_;
50 51
  // whether enable fp16
  bool enable_fp16_{false};
52 53
  // whether enable int8
  bool enable_int8_{false};
54 55 56 57 58 59
};

}  // namespace codegen
}  // namespace tvm

#endif  // TVM_CODEGEN_CODEGEN_CUDA_H_