Commit 3a0d3a39 by Tianqi Chen Committed by GitHub

[PASS] Fix intrinsic lowering with fma and other intrin (#457)

* [PASS] Fix intrinsic lowering with fma and other intrin

* relax rtol for sqrt
parent 5ae1a079
......@@ -40,11 +40,11 @@ class IntrinInjecter : public IRMutator {
if (const Mul* mb = op->b.as<Mul>()) {
Expr r = (*fma_)(Call::make(
op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
if (r.defined()) return r;
if (r.defined()) return this->Mutate(r);
} else if (const Mul* ma = op->a.as<Mul>()) {
Expr r = (*fma_)(Call::make(
op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
if (r.defined()) return r;
if (r.defined()) return this->Mutate(r);
}
return IRMutator::Mutate_(op, e);
}
......
......@@ -65,7 +65,7 @@ def test_llvm_persist_parallel():
n = 128
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
C = tvm.compute(A.shape, lambda *i: B(*i) + 2, name='C')
C = tvm.compute(A.shape, lambda *i: tvm.sqrt(B(*i)) * 2 + 2, name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=8)
xo1, xo2 = s[C].split(xo, nparts=1)
......@@ -86,7 +86,9 @@ def test_llvm_persist_parallel():
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
f(a, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 3)
np.testing.assert_allclose(c.asnumpy(),
np.sqrt(a.asnumpy() + 1) * 2 + 2,
rtol=1e-5)
check_llvm()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment