/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_arm.cc
 * \brief ARM specific code generator
 */
#ifdef TVM_LLVM_VERSION
#include "codegen_cpu.h"

namespace tvm {
namespace codegen {

// ARM-specific code generator. It also serves as an example of how to
// override the behavior of the LLVM code generator for a specific target.
class CodeGenARM final : public CodeGenCPU {
 public:
  /*!
   * \brief Initialize the target, declaring 128-bit native vectors for ARM.
   * \param tm The LLVM target machine to generate code for.
   */
  void InitTarget(llvm::TargetMachine* tm) final {
    // 128 bits == 16 bytes; assigned before handing off to the CPU base class.
    native_vector_bits_ = 128;
    CodeGenCPU::InitTarget(tm);
  }

  /*!
   * \brief Lower an intrinsic call, giving llvm.ctpop an ARM-specific expansion.
   * \param op The intrinsic call node.
   * \return The generated LLVM value.
   */
  llvm::Value* CreateIntrinsic(const Call* op) override;

 private:
  /*!
   * \brief Rewrite an llvm.ctpop `llvm_intrin` call into an ARM-friendly expression.
   * \param op The original ctpop call node.
   * \return The rewritten expression.
   */
  Expr ARMPopcount(const Call* op);
};

/*!
 * \brief Lower an intrinsic call for ARM.
 *
 * `llvm_intrin` calls whose intrinsic id is llvm.ctpop are first rewritten by
 * ARMPopcount; every other call is delegated unchanged to CodeGenCPU.
 *
 * \param op The intrinsic call node.
 * \return The generated LLVM value.
 */
llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
  if (op->is_intrinsic("llvm_intrin")) {
    // Fail with a diagnostic instead of dereferencing a null downcast.
    CHECK_GE(op->args.size(), 1U)
        << "llvm_intrin call is missing its intrinsic id argument";
    const UIntImm* id_imm = op->args[0].as<UIntImm>();
    CHECK(id_imm != nullptr)
        << "llvm_intrin expects its first argument to be a UIntImm intrinsic id";
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(id_imm->value);
    if (id == ::llvm::Intrinsic::ctpop) {
      // ARMPopcount reads the ctpop operand from args[2].
      CHECK_GE(op->args.size(), 3U)
          << "llvm_intrin ctpop call is missing its operand";
      Expr e = ARMPopcount(op);
      // Qualified (non-virtual) call so the rewritten ctpop call is not routed
      // back through this override, which would re-expand it endlessly.
      return CodeGenCPU::CreateIntrinsic(e.as<Call>());
    }
  }
  return CodeGenCPU::CreateIntrinsic(op);
}

namespace {

/*!
 * \brief Build a pure `llvm_intrin` call applying the unary LLVM intrinsic
 *        `id` to `arg`, producing a value of type `t`.
 *
 * The argument layout is [intrinsic id, 1, arg]. NOTE(review): the `1` is
 * presumably the number of overloaded signature types consumed by
 * CodeGenLLVM — confirm there.
 */
Expr MakeUnaryLLVMIntrin(const Type& t, ::llvm::Intrinsic::ID id, const Expr& arg) {
  Array<Expr> args;
  args.push_back(ir::UIntImm::make(UInt(32), id));
  args.push_back(ir::UIntImm::make(UInt(32), 1));
  args.push_back(arg);
  return ir::Call::make(t, "llvm_intrin", args, ir::Call::PureIntrinsic);
}

}  // namespace

/*!
 * \brief Rewrite an llvm.ctpop call into an ARM NEON friendly expression.
 *
 * For 64-bit or 128-bit vectors with 16/32/64-bit lanes, the input is
 * reinterpreted as 8-bit lanes and popcounted there. The per-byte counts are
 * then pairwise-added and widened with arm_neon_vpaddlu (8->16->32->64) until
 * the original lane width is reached. Every other type falls back to a plain
 * llvm.ctpop call.
 *
 * \param call The original `llvm_intrin` ctpop call; its operand is args[2].
 * \return The rewritten expression (always an `llvm_intrin` Call).
 */
Expr CodeGenARM::ARMPopcount(const Call *call) {
  using namespace ir;
  const Expr& e = call->args[2];
  const ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
  const ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;

  const int lane_bits = call->type.bits();
  const int total_size = lane_bits * call->type.lanes();
  // The widening chain below can only stop at 16, 32 or 64-bit lanes. 8-bit
  // lanes need no widening at all. Other widths (e.g. 1-bit or 4-bit lanes)
  // would make the final vpaddlu step produce an invalid result type.
  const bool widenable_lanes = lane_bits == 16 || lane_bits == 32 || lane_bits == 64;
  // Fall back to the default LLVM lowering unless the input is a full (128-bit)
  // or half (64-bit) vector with widenable lanes.
  if (!call->type.is_vector() || !widenable_lanes ||
      (total_size != 128 && total_size != 64)) {
    return MakeUnaryLLVMIntrin(call->type, ctpop_id, e);
  }

  // Popcount lowering rule:
  // Reinterpret the input vector as a vector of 8-bit values and perform the
  // popcount there. Then repeatedly pairwise-add adjacent elements while
  // doubling their width (vpaddlu) until the original lane width is reached.

  // All divisions are exact: the total vector size is 64 or 128 bits.
  const int in_bits = e.type().bits() * e.type().lanes();
  const Type uint8_type = Type(e.type().code(), 8, in_bits / 8);
  const Type uint16_type = Type(e.type().code(), 16, in_bits / 16);
  const Type uint32_type = Type(e.type().code(), 32, in_bits / 32);

  // Popcount 8bit->8bit on the byte-reinterpreted input.
  Expr vcnt8 = MakeUnaryLLVMIntrin(uint8_type, ctpop_id, reinterpret(uint8_type, e));

  // Accumulation 8->16bit.
  Expr vcnt16 = MakeUnaryLLVMIntrin(uint16_type, vpaddlu_id, vcnt8);
  if (lane_bits == 16) {
    return vcnt16;
  }

  // Accumulation 16->32bit.
  Expr vcnt32 = MakeUnaryLLVMIntrin(uint32_type, vpaddlu_id, vcnt16);
  if (lane_bits == 32) {
    return vcnt32;
  }

  // Accumulation 32->64bit (lane_bits == 64 is guaranteed by the guard above).
  return MakeUnaryLLVMIntrin(call->type, vpaddlu_id, vcnt32);
}

// Factory for the ARM code generator. The new generator is handed back to the
// caller as an opaque void* to a CodeGenLLVM. NOTE(review): presumably the
// caller takes ownership of this allocation — confirm at the call site.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& /*targs*/, TVMRetValue* rv) {
    // Upcast to the base type before erasing, as consumers expect a CodeGenLLVM*.
    CodeGenLLVM* generator = new CodeGenARM();
    *rv = static_cast<void*>(generator);
  });

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION