Commit af9f69a7 by Yuwei Hu, committed by Tianqi Chen

[INTRIN] enable popcount on cuda, opencl, metal (#774)

parent e4a51303
@@ -30,18 +30,14 @@ struct FloatSuffix {
   }
 };
 
-// Add float suffix to the intrinsics
-struct FloatDirect {
+// Return the intrinsic name
+struct Direct {
   std::string operator()(Type t, std::string name) const {
-    if (t.is_float()) {
-      return name;
-    } else {
-      return "";
-    }
+    return name;
   }
 };
 
-// Directly call pure extern function for floats.
+// Call pure extern function.
 template<typename T>
 inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
   Expr e = args[0];
...
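Note: DispatchExtern<T> lowers an intrinsic call into a pure extern call whose symbol is produced by the functor T. The old FloatDirect returned a name only for float types, so integer intrinsics such as popcount could never take this path; the renamed Direct forwards the name unconditionally and leaves any type-specific mapping to per-backend functors such as CUDAPopcount below. For illustration only, the same rule registry can be driven from Python. The sketch below is not part of the patch and assumes the tvm.register_func / tvm.call_pure_extern API of this TVM generation:

    import tvm

    def cuda_popcount_rule(op):
        # Hypothetical Python-side equivalent of CUDAPopcount: map
        # uint32/uint64 popcount calls onto the CUDA builtins.
        assert isinstance(op, tvm.expr.Call)
        if op.dtype == "uint32":
            return tvm.call_pure_extern(op.dtype, "__popc", op.args[0])
        if op.dtype == "uint64":
            return tvm.call_pure_extern(op.dtype, "__popcll", op.args[0])
        return op  # leave other types to the default lowering

    # override=True replaces the rule registered on the C++ side.
    tvm.register_func("tvm.intrin.rule.cuda.popcount",
                      cuda_popcount_rule, override=True)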
@@ -36,6 +36,19 @@ struct CUDAFastMath : public CUDAMath {
   }
 };
 
+struct CUDAPopcount {
+  std::string operator()(Type t, std::string name) const {
+    if (t.lanes() == 1 && t.is_uint()) {
+      switch (t.bits()) {
+        case 32: return "__popc";
+        case 64: return "__popcll";
+        default: return "";
+      }
+    }
+    return "";
+  }
+};
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 .set_body(DispatchExtern<CUDAFastMath>);
 
@@ -51,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
 .set_body(DispatchExtern<CUDAMath>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
+.set_body(DispatchExtern<CUDAPopcount>);
+
 } // namespace intrin
 } // namespace codegen
 } // namespace tvm
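Note: CUDAPopcount only maps scalar unsigned types, uint32 to __popc and uint64 to __popcll; everything else yields an empty name so no extern substitution is made. A quick way to confirm the mapping is to inspect the generated kernel. This is only a sketch; it assumes a CUDA-enabled build and that the device source can be read back through imported_modules[0].get_source():

    import tvm

    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A', dtype='uint32')
    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=8)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    if tvm.module.enabled("cuda"):
        func = tvm.build(s, [A, B], "cuda")
        # The generated kernel should call the CUDA builtin __popc.
        assert "__popc" in func.imported_modules[0].get_source()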
@@ -10,19 +10,22 @@ namespace codegen {
 namespace intrin {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
+.set_body(DispatchExtern<Direct>);
 
 } // namespace intrin
 } // namespace codegen
...
@@ -10,19 +10,22 @@ namespace codegen {
 namespace intrin {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
+.set_body(DispatchExtern<Direct>);
 
 } // namespace intrin
 } // namespace codegen
...
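Note: Metal and OpenCL need no renaming because both OpenCL C and the Metal Standard Library ship a built-in popcount, so the Direct rule passes the intrinsic name through and the generated device source calls popcount(...) directly. A minimal check of that pass-through, under the same assumptions as the CUDA sketch above:

    import tvm

    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A', dtype='uint32')
    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=8)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    for target in ("opencl", "metal"):
        if tvm.module.enabled(target):
            mod = tvm.build(s, [A, B], target)
            # The builtin keeps the name "popcount" on these backends.
            assert "popcount" in mod.imported_modules[0].get_source()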
@@ -60,25 +60,40 @@ def test_log_pow_llvm():
         b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
 
 
-def test_popcount_llvm():
-    # graph
-    n = tvm.var('n')
-    A = tvm.placeholder((n,), name='A', dtype="uint32")
-    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
-    s = tvm.create_schedule(B.op)
-
-    if not tvm.module.enabled("llvm"):
-        return
-    f = tvm.build(s, [A, B], "llvm")
-    ctx = tvm.cpu(0)
-    # launch the kernel.
-    n = 1024
-    a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
-    b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
-    f(a, b)
-    np.testing.assert_allclose(
-        b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
+def test_popcount():
+    def run(dtype):
+        # graph
+        n = tvm.convert(1024)
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
+        s = tvm.create_schedule(B.op)
+        # simple schedule
+        num_thread = 8
+        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
+
+        def check_device(device):
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled.." % device)
+                return
+            ctx = tvm.context(device, 0)
+            if str(ctx).startswith('gpu'):
+                s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+                s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+            func = tvm.build(s, [A, B], device)
+            # launch the kernel.
+            n = 1024
+            a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
+            b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
+            func(a, b)
+            np.testing.assert_allclose(
+                b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
+
+        check_device("llvm")
+        check_device("cuda")
+        check_device("opencl")
+        check_device("metal")
+
+    run('uint32')
+    run('uint64')
 
 
 def test_add():
@@ -133,5 +148,5 @@ def test_add():
 if __name__ == "__main__":
     test_add()
     test_log_pow_llvm()
-    test_popcount_llvm()
+    test_popcount()
     test_exp()
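Note: the test's reference result maps Python's bin(x).count('1') over every element, which is fine for 1024 values. A vectorized NumPy alternative (just a sketch, not part of the patch) unpacks each element's bytes and sums the bits:

    import numpy as np

    def np_popcount(a):
        # View each element as raw bytes, unpack to individual bits,
        # and sum the bits belonging to each element.
        bits = np.unpackbits(a.view(np.uint8))
        return bits.reshape(a.shape[0], -1).sum(axis=1)

    x = np.random.randint(0, 1000, size=16, dtype='uint32')
    assert list(np_popcount(x)) == [bin(v).count('1') for v in x]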