diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc
index 43989a6..84f3cb5 100644
--- a/src/pass/lower_intrin.cc
+++ b/src/pass/lower_intrin.cc
@@ -40,11 +40,11 @@ class IntrinInjecter : public IRMutator {
     if (const Mul* mb = op->b.as<Mul>()) {
       Expr r = (*fma_)(Call::make(
           op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
-      if (r.defined()) return r;
+      if (r.defined()) return this->Mutate(r);
     } else if (const Mul* ma = op->a.as<Mul>()) {
       Expr r = (*fma_)(Call::make(
           op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
-      if (r.defined()) return r;
+      if (r.defined()) return this->Mutate(r);
     }
     return IRMutator::Mutate_(op, e);
   }
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index cb17130..0db06b9 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -65,7 +65,7 @@ def test_llvm_persist_parallel():
     n = 128
     A = tvm.placeholder((n,), name='A')
     B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
-    C = tvm.compute(A.shape, lambda *i: B(*i) + 2, name='C')
+    C = tvm.compute(A.shape, lambda *i: tvm.sqrt(B(*i)) * 2 + 2, name='C')
     s = tvm.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=8)
     xo1, xo2 = s[C].split(xo, nparts=1)
@@ -86,7 +86,9 @@ def test_llvm_persist_parallel():
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         f(a, c)
-        np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 3)
+        np.testing.assert_allclose(c.asnumpy(),
+                                   np.sqrt(a.asnumpy() + 1) * 2 + 2,
+                                   rtol=1e-5)
 
     check_llvm()