codegen_arm.cc 4.98 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24
/*!
 * \file codegen_arm.cc
 * \brief ARM specific code generator
 */
#ifdef TVM_LLVM_VERSION
25 26 27

#include <tvm/runtime/registry.h>

28
#include "codegen_cpu.h"
29 30 31 32 33 34

namespace tvm {
namespace codegen {

// ARM specific code generator, this is used as an example on
// how to override behavior llvm code generator for specific target
35
class CodeGenARM final : public CodeGenCPU {
36 37 38 39
 public:
  void InitTarget(llvm::TargetMachine* tm) final {
    // set native vector bits.
    native_vector_bits_ = 16 * 8;
40
    CodeGenCPU::InitTarget(tm);
41
  }
42
  llvm::Value* CreateIntrinsic(const CallNode* op) override;
43 44

 private:
45
  PrimExpr ARMPopcount(const CallNode* op);
46 47
};

48
/*!
 * \brief Emit an intrinsic call, rewriting ctpop before handing it to CodeGenCPU.
 *
 * Only calls to "llvm_intrin" whose first argument is the ctpop intrinsic id
 * are rewritten (through ARMPopcount); everything else is forwarded to
 * CodeGenCPU::CreateIntrinsic untouched.
 *
 * \param op The intrinsic call node.
 * \return The LLVM value produced by CodeGenCPU::CreateIntrinsic.
 */
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
  // Anything that is not an explicit LLVM intrinsic request goes to the base class.
  if (!op->is_intrinsic("llvm_intrin")) {
    return CodeGenCPU::CreateIntrinsic(op);
  }

  // The first argument carries the LLVM intrinsic id as an integer immediate.
  const llvm::Intrinsic::ID intrin_id =
      static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
  if (intrin_id != ::llvm::Intrinsic::ctpop) {
    return CodeGenCPU::CreateIntrinsic(op);
  }

  // Keep the rewritten expression alive in a local while its node is lowered.
  PrimExpr lowered = ARMPopcount(op);
  return CodeGenCPU::CreateIntrinsic(lowered.as<CallNode>());
}

60
/*!
 * \brief Build the ARM lowering of a ctpop "llvm_intrin" call.
 *
 * For 64-bit or 128-bit vectors with 16, 32 or 64-bit elements the operand is
 * reinterpreted as a vector of 8-bit values, popcounted per byte, and the byte
 * counts are then pairwise-added with arm_neon_vpaddlu (doubling the element
 * width each step) until the element width of the original type is reached.
 * Every other type gets a plain ctpop "llvm_intrin" call on the operand.
 *
 * \param call The original "llvm_intrin" call; the operand is read from args[2].
 * \return An "llvm_intrin" call expression whose dtype is call->dtype.
 */
PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
  using namespace tir;
  const PrimExpr& e = call->args[2];
  const ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
  const ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;

  // Wraps one operand into a pure "llvm_intrin" call: {intrinsic id, 1, operand}.
  // NOTE(review): the constant 1 is presumably the number of signature
  // arguments expected by the llvm_intrin convention -- confirm in CodeGenLLVM.
  auto make_intrin = [](const DataType& ret_type, ::llvm::Intrinsic::ID id,
                        const PrimExpr& arg) -> PrimExpr {
    Array<PrimExpr> args;
    args.push_back(IntImm(DataType::UInt(32), id));
    args.push_back(IntImm(DataType::UInt(32), 1));
    args.push_back(arg);
    return tir::CallNode::make(ret_type, "llvm_intrin", args, CallNode::PureIntrinsic);
  };

  // Fall back to the default llvm lowering rule if the input type is not a full
  // vector or half vector length, or if its element width is not one the
  // widening chain below can produce. The chain only yields 16, 32 or 64-bit
  // elements; 8-bit elements need no widening, and any other width would
  // otherwise reach the last vpaddlu step with a mismatching result type.
  const int elem_bits = call->dtype.bits();
  const int total_size = elem_bits * call->dtype.lanes();
  const bool widenable = (elem_bits == 16 || elem_bits == 32 || elem_bits == 64);
  if (!call->dtype.is_vector() || !widenable ||
      (total_size != 128 && total_size != 64)) {
    return make_intrin(call->dtype, ctpop_id, e);
  }

  // Popcount lowering rule:
  // Reinterpret the input vector as a vector of 8bit values and perform popcount.
  // Pairwise add between adjacent elements and double the width with vpaddlu
  // to return back to the original input type.

  // Divisions are always exact (number of bits = 64 or 128).
  // The type code is inherited from the operand, so these are not necessarily
  // unsigned types.
  DataType vec8_type = DataType(
      e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8);
  DataType vec16_type = DataType(
      vec8_type.code(), 16, vec8_type.bits() * vec8_type.lanes() / 16);
  DataType vec32_type = DataType(
      vec16_type.code(), 32, vec8_type.bits() * vec8_type.lanes() / 32);

  // Interpret input as vector of 8bit values; the result must be a call node.
  PrimExpr input8 = reinterpret(vec8_type, e);
  CHECK(input8.as<CallNode>() != nullptr);

  // Popcount 8bit->8bit
  PrimExpr vcnt8 = make_intrin(vec8_type, ctpop_id, input8);

  // Accumulation 8->16bit
  PrimExpr vcnt16 = make_intrin(vec16_type, vpaddlu_id, vcnt8);
  if (elem_bits == 16) {
    return vcnt16;
  }

  // Accumulation 16->32bit
  PrimExpr vcnt32 = make_intrin(vec32_type, vpaddlu_id, vcnt16);
  if (elem_bits == 32) {
    return vcnt32;
  }

  // Accumulation 32->64bit
  return make_intrin(call->dtype, vpaddlu_id, vcnt32);
}

133 134 135 136 137 138 139 140 141
// Registers the global function "tvm.codegen.llvm.target_arm", which creates a
// CodeGenARM and returns it through rv as an opaque pointer.
// NOTE(review): the returned raw pointer is presumably owned (and deleted) by
// whoever invokes this factory -- confirm against the caller.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
    // Convert to the CodeGenLLVM base pointer first, then erase the type.
    *rv = static_cast<void*>(static_cast<CodeGenLLVM*>(new CodeGenARM()));
  });

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION