codegen_metal.h 1.29 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_metal.h
 * \brief Generate Metal device code.
 */
#ifndef TVM_CODEGEN_CODEGEN_METAL_H_
#define TVM_CODEGEN_CODEGEN_METAL_H_

#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
12
#include "codegen_c.h"
13 14 15 16 17 18 19 20 21 22 23 24 25

namespace tvm {
namespace codegen {

class CodeGenMetal final : public CodeGenC {
 public:
  CodeGenMetal();
  void AddFunction(LoweredFunc f);
  // override print thread tag.
  void PrintArgUnionDecl();
  void InitFuncState(LoweredFunc f) final;
  void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
  void PrintStorageSync(const Call* op) final;  // NOLINT(*)
26
  void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
27
  void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
28 29 30 31 32 33
  // print load of single element
  void PrintVecElemLoad(
      const std::string& vec, Type t, int i, std::ostream& os) final;  // NOLINT(*)
  // print store of single element.
  void PrintVecElemStore(
      const std::string& vec, Type t, int i, const std::string& value) final;
34 35
  // overload visitor
  void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
36 37 38

 private:
  int thread_index_bits_{32};
39 40 41 42 43
};
}  // namespace codegen
}  // namespace tvm

#endif  // TVM_CODEGEN_CODEGEN_METAL_H_