Commit af9f69a7 by Yuwei Hu Committed by Tianqi Chen

[INTRIN] enable popcount on cuda, opencl, metal (#774)

parent e4a51303
......@@ -30,18 +30,14 @@ struct FloatSuffix {
}
};
// Add float suffix to the intrinsics
struct FloatDirect {
// Return the intrinsic name
struct Direct {
std::string operator()(Type t, std::string name) const {
if (t.is_float()) {
return name;
} else {
return "";
}
}
};
// Directly call pure extern function for floats.
// Call pure extern function.
template<typename T>
inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
......
......@@ -36,6 +36,19 @@ struct CUDAFastMath : public CUDAMath {
}
};
struct CUDAPopcount {
std::string operator()(Type t, std::string name) const {
if (t.lanes() == 1 && t.is_uint()) {
switch (t.bits()) {
case 32: return "__popc";
case 64: return "__popcll";
default: return "";
}
}
return "";
}
};
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);
......@@ -51,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -10,19 +10,22 @@ namespace codegen {
namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
......
......@@ -10,19 +10,22 @@ namespace codegen {
namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
.set_body(DispatchExtern<FloatDirect>);
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
......
......@@ -60,25 +60,40 @@ def test_log_pow_llvm():
b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
def test_popcount_llvm():
def test_popcount():
def run(dtype):
# graph
n = tvm.var('n')
A = tvm.placeholder((n,), name='A', dtype="uint32")
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A', dtype=dtype)
B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
s = tvm.create_schedule(B.op)
# simple schedule
num_thread = 8
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
if not tvm.module.enabled("llvm"):
def check_device(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
ctx = tvm.context(device, 0)
if str(ctx).startswith('gpu'):
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
func = tvm.build(s, [A, B], device)
# launch the kernel.
n = 1024
a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
f(a, b)
func(a, b)
np.testing.assert_allclose(
b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
check_device("llvm")
check_device("cuda")
check_device("opencl")
check_device("metal")
run('uint32')
run('uint64')
def test_add():
......@@ -133,5 +148,5 @@ def test_add():
if __name__ == "__main__":
test_add()
test_log_pow_llvm()
test_popcount_llvm()
test_popcount()
test_exp()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment