diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 43989a6..84f3cb5 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -40,11 +40,11 @@ class IntrinInjecter : public IRMutator { if (const Mul* mb = op->b.as<Mul>()) { Expr r = (*fma_)(Call::make( op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic)); - if (r.defined()) return r; + if (r.defined()) return this->Mutate(r); } else if (const Mul* ma = op->a.as<Mul>()) { Expr r = (*fma_)(Call::make( op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic)); - if (r.defined()) return r; + if (r.defined()) return this->Mutate(r); } return IRMutator::Mutate_(op, e); } diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index cb17130..0db06b9 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -65,7 +65,7 @@ def test_llvm_persist_parallel(): n = 128 A = tvm.placeholder((n,), name='A') B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B') - C = tvm.compute(A.shape, lambda *i: B(*i) + 2, name='C') + C = tvm.compute(A.shape, lambda *i: tvm.sqrt(B(*i)) * 2 + 2, name='C') s = tvm.create_schedule(C.op) xo, xi = s[C].split(C.op.axis[0], factor=8) xo1, xo2 = s[C].split(xo, nparts=1) @@ -86,7 +86,9 @@ def test_llvm_persist_parallel(): a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) f(a, c) - np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 3) + np.testing.assert_allclose(c.asnumpy(), + np.sqrt(a.asnumpy() + 1) * 2 + 2, + rtol=1e-5) check_llvm()