Commit f3f406ab by Meghan Cowan Committed by Tianqi Chen

[CODEGEN] ARM Popcount lowering rule and codegen updates (#1235)

parent dc6203c2
...@@ -18,8 +18,90 @@ class CodeGenARM final : public CodeGenCPU { ...@@ -18,8 +18,90 @@ class CodeGenARM final : public CodeGenCPU {
native_vector_bits_ = 16 * 8; native_vector_bits_ = 16 * 8;
CodeGenCPU::InitTarget(tm); CodeGenCPU::InitTarget(tm);
} }
llvm::Value* CreateIntrinsic(const Call* op) override;
private:
Expr ARMPopcount(const Call* op);
}; };
llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic("llvm_intrin")) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value);
if (id == ::llvm::Intrinsic::ctpop) {
Expr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<Call>());
}
}
return CodeGenCPU::CreateIntrinsic(op);
}
Expr CodeGenARM::ARMPopcount(const Call *call) {
using namespace ir;
const Expr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;
// Fallback to default llvm lowering rule if input type not a full vector or half vector length
int total_size = call->type.bits() * call->type.lanes();
if (!call->type.is_vector() || call->type.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<Expr> vcnt_args;
vcnt_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
vcnt_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt_args.push_back(e);
return ir::Call::make(call->type, "llvm_intrin", vcnt_args, Call::PureIntrinsic);
}
// Popcount lowering rule:
// Reinterpret input vector as a vector of 8bit values and preform popcount
// Pairwise add between adjacent elements and double width with vpaddlu
// to return back to original input type
// Dvisions are always divisible (number of bits = 64 or 128)
Type uint8_type = Type(e.type().code(), 8, e.type().bits() * e.type().lanes() / 8);
Type uint16_type = Type(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
Type uint32_type = Type(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
// Interpret input as vector of 8bit values
Expr input8 = reinterpret(uint8_type, e);
// Popcount 8bit->8bit
const Call* c0 = input8.as<Call>();
CHECK(c0 != nullptr);
Array<Expr> vcnt8_args;
vcnt8_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
vcnt8_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt8_args.push_back(input8);
Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic);
// Accumulation 8->16bit
Array<Expr> vcnt16_args;
vcnt16_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt16_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt16_args.push_back(vcnt8);
Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic);
if (call->type.bits() == 16) {
return vcnt16;
}
// Accumulation 16->32bit
Array<Expr> vcnt32_args;
vcnt32_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt32_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt32_args.push_back(vcnt16);
Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic);
if (call->type.bits() == 32) {
return vcnt32;
}
// Accumulation 32->64bit
Array<Expr> vcnt64_args;
vcnt64_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt64_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt64_args.push_back(vcnt32);
return ir::Call::make(call->type, "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) { .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenARM(); CodeGenLLVM* cg = new CodeGenARM();
......
...@@ -366,7 +366,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { ...@@ -366,7 +366,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements()); int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (extent == num_elems && begin == 0) return vec; if (extent == num_elems && begin == 0) return vec;
CHECK_LT(begin + extent, num_elems); CHECK_LE(begin + extent, num_elems);
std::vector<unsigned> indices; std::vector<unsigned> indices;
for (int i = 0; i < extent; ++i) { for (int i = 0; i < extent; ++i) {
indices.push_back(begin + i); indices.push_back(begin + i);
...@@ -562,6 +562,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -562,6 +562,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
sig_type.push_back(arg_value.back()->getType()); sig_type.push_back(arg_value.back()->getType());
} }
} }
llvm::Type *return_type = LLVMType(op->type);
if (sig_type.size() > 0 && return_type != sig_type[0]) {
sig_type.insert(sig_type.begin(), return_type);
}
llvm::Function* f = llvm::Intrinsic::getDeclaration( llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, sig_type); module_.get(), id, sig_type);
return builder_->CreateCall(f, arg_value); return builder_->CreateCall(f, arg_value);
...@@ -628,6 +632,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -628,6 +632,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
value->addIncoming(then_value, then_value_block); value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block); value->addIncoming(else_value, else_value_block);
return value; return value;
} else if (op->is_intrinsic(Call::reinterpret)) {
llvm::Type * target = LLVMType(op->type);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
return CreateVecSlice(v, 0, l/2);
} else if (op->is_intrinsic("vectorhigh")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
return CreateVecSlice(v, l/2, l/2);
} else if (op->is_intrinsic("vectorcombine")) {
llvm::Value *v0 = MakeValue(op->args[0]);
llvm::Value *v1 = MakeValue(op->args[1]);
int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
std::vector<unsigned> indices;
for (int i = 0; i < num_elems; ++i) {
indices.push_back(i);
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else { } else {
LOG(FATAL) << "unknown intrinsic " << op->name; LOG(FATAL) << "unknown intrinsic " << op->name;
return nullptr; return nullptr;
......
...@@ -117,11 +117,41 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -117,11 +117,41 @@ class LLVMModuleNode final : public runtime::ModuleNode {
} }
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
std::string fmt = runtime::GetFileFormat("", format);
std::string type_str; std::string type_str;
llvm::raw_string_ostream rso(type_str); llvm::SmallString<256> str;
CHECK(mptr_ != nullptr); llvm::raw_svector_ostream rso(str);
mptr_->print(rso, nullptr);
return rso.str(); if (fmt == "s" || fmt == "asm") {
#if TVM_LLVM_VERSION <= 60
std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
#else
std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
llvm::legacy::PassManager pass;
CHECK(tm_);
#if TVM_LLVM_VERSION <= 60
CHECK(tm_->addPassesToEmitFile(
pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#else
CHECK(tm_->addPassesToEmitFile(
pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#endif
pass.run(*m);
return rso.str().str();
} else if (fmt == "" || fmt == "ll") {
std::string type_str;
llvm::raw_string_ostream rso(type_str);
CHECK(mptr_ != nullptr);
mptr_->print(rso, nullptr);
return rso.str();
} else {
LOG(FATAL) << "Do not know how to get source code with format: "
<< format << "\'";
}
return "";
} }
void Init(const Array<LoweredFunc>& funcs, std::string target) { void Init(const Array<LoweredFunc>& funcs, std::string target) {
......
import tvm
import re
import os
import ctypes
def test_popcount():
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
def check_correct_assembly(type, elements, counts):
n = tvm.convert(elements)
A = tvm.placeholder(n, dtype=type, name='A')
B = tvm.compute(A.shape, lambda i: tvm.popcount(A[i]), name='B')
s = tvm.create_schedule(B.op)
s[B].vectorize(s[B].op.axis[0])
f = tvm.build(s, [A, B], target)
# Verify we see the correct number of vpaddl and vcnt instructions in the assembly
assembly = f.get_source('asm')
matches = re.findall("vpaddl", assembly)
assert (len(matches) == counts)
matches = re.findall("vcnt", assembly)
assert (len(matches) == 1)
check_correct_assembly('uint16', 8, 1)
check_correct_assembly('uint16', 4, 1)
check_correct_assembly('uint32', 4, 2)
check_correct_assembly('uint32', 2, 2)
check_correct_assembly('uint64', 2, 3)
if __name__ == "__main__":
test_popcount()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment