Commit 32f158e6 authored by xqdan, committed by Tianqi Chen

[intrin]support fmod for cuda (#1964)

parent 86f4f1b7
......@@ -376,6 +376,22 @@ def popcount(x):
"""
return call_pure_intrin(x.dtype, "popcount", x)
def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
Parameters
----------
x : Expr
Input argument.
y : Expr
Input argument.
Returns
-------
z : Expr
The result.
"""
return call_pure_intrin(x.dtype, "fmod", x, y)
# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
......
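For orientation, here is a minimal usage sketch of the new intrinsic with the placeholder/compute front end used by the test further down; tvm.lower is only used here to print the IR and is not part of this change:

import tvm

# Minimal sketch: elementwise remainder via the new tvm.fmod intrinsic.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.placeholder((n,), name="B", dtype="float32")
C = tvm.compute(A.shape, lambda i: tvm.fmod(A[i], B[i]), name="C")
s = tvm.create_schedule(C.op)
# The lowered IR should show fmod as a pure intrinsic call.
print(tvm.lower(s, [A, B, C], simple_mode=True))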
......@@ -91,6 +91,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchExtern<CUDAShuffle>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);
} // namespace intrin
} // namespace codegen
......
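DispatchExtern<CUDAMath> lowers the "fmod" pure intrinsic to a call to the matching CUDA math function (fmodf for float32, fmod for float64). As a rough, hedged illustration only, a similar effect could be expressed from Python with the register_intrin_rule hook shown in intrin.py above; the helper name below is hypothetical and not part of this commit:

import tvm

# Hypothetical Python-side analogue of DispatchExtern<CUDAMath> for fmod.
def my_cuda_fmod_rule(op):
    # float32 lowers to the CUDA device function fmodf, float64 to fmod.
    name = "fmodf" if op.dtype == "float32" else "fmod"
    return tvm.call_pure_extern(op.dtype, name, op.args[0], op.args[1])

tvm.register_intrin_rule("cuda", "fmod", my_cuda_fmod_rule, override=True)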
......@@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod")
.set_body(DispatchExtern<Direct>);
// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
struct IntelShuffle {
......
......@@ -450,4 +450,10 @@ Expr prod(Expr source, Array<IterVar> rdom) {
  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr fmod(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.type().is_float()) << "fmod only applies to float";
  return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic);
}
} // namespace tvm
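In short, the C++ helper unifies the operand types, restricts fmod to floating-point inputs, and builds a pure-intrinsic Call node named "fmod"; the per-target rules registered above then decide how that call is lowered (an extern CUDA math call via CUDAMath, or the language's built-in fmod on Metal and OpenCL via Direct).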
......@@ -38,6 +38,45 @@ def test_exp():
check_device("cuda", "llvm")
check_device("vulkan")
def test_fmod():
# graph
def run(dtype):
n = tvm.var('n')
A = tvm.placeholder((n,), name='A', dtype=dtype)
B = tvm.placeholder((n,), name='B', dtype=dtype)
C = tvm.compute(A.shape, lambda *i: tvm.fmod(A(*i), B(*i)), name='C')
s = tvm.create_schedule(C.op)
# create iter var and assign them tags.
num_thread = 8
bx, tx = s[C].split(C.op.axis[0], factor=num_thread)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
target = tvm.target.create(device)
if "cpu" not in target.keys:
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
fmod = tvm.build(s, [A, B, C], device, name="myfmod")
# launch the kernel.
n = 1024
a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx)
b = tvm.nd.array((np.random.uniform(size=n) * 256).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
ftimer = fmod.time_evaluator(fmod.entry_name, ctx, number=1)
tcost = ftimer(a, b, c).mean
#fmod(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), np.mod(a.asnumpy(), b.asnumpy()), rtol=1e-5)
check_device("cuda")
check_device("opencl -device=intel_graphics")
check_device("metal")
run("float32")
def test_multiple_cache_write():
# graph
......@@ -245,3 +284,4 @@ if __name__ == "__main__":
    test_add()
    test_log_pow_llvm()
    test_popcount()
    test_fmod()
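A note on the reference used in the test: np.mod follows the sign of the divisor, while fmod (and np.fmod) follows the sign of the dividend, matching the docstring above. The test's inputs are non-negative, so the two agree there; with negative operands they differ, for example:

import numpy as np

# fmod keeps the sign of x (the dividend); np.mod keeps the sign of y (the divisor).
print(np.fmod(-5.3, 2.0))   # -> -1.3
print(np.mod(-5.3, 2.0))    # ->  0.7 (approximately)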