/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_opencl.h
 * \brief Generate OpenCL device code.
 */
#ifndef TVM_CODEGEN_CODEGEN_OPENCL_H_
#define TVM_CODEGEN_CODEGEN_OPENCL_H_

#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
#include "codegen_c.h"

namespace tvm {
namespace codegen {

class CodeGenOpenCL final : public CodeGenC {
36
 public:
37
  CodeGenOpenCL();
38
  void AddFunction(LoweredFunc f);
39 40
  std::string Finish();

41
  // override print thread tag.
42
  void InitFuncState(LoweredFunc f) final;
43
  void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
44
  void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
45
  void PrintStorageSync(const CallNode* op) final;  // NOLINT(*)
46
  void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
47
  std::string GetVecLoad(DataType t, const VarNode* buffer,
48
                         PrimExpr base) final;
49
  void PrintVecStore(const VarNode* buffer,
50
                     DataType t, PrimExpr base,
51
                     const std::string& value) final;  // NOLINT(*)
52
  // the address of load/store
53
  void PrintVecAddr(const VarNode* buffer, DataType t,
54
                    PrimExpr base, std::ostream& os);  // NOLINT(*)
55
  std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
56

57
  // overload visitor
58 59 60 61
  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
  void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
  void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
62 63 64 65 66

 private:
  // whether enable fp16 and fp64 extension
  bool enable_fp16_{false};
  bool enable_fp64_{false};
67 68 69 70 71 72
};

}  // namespace codegen
}  // namespace tvm

#endif  // TVM_CODEGEN_CODEGEN_OPENCL_H_