Unverified Commit 90eee087 by Tianqi Chen Committed by GitHub

[TEST] Fix testcase to make them more compatible to zero-rank (#3612)

parent 814554e0
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
import topi
from tvm.contrib import util, clang from tvm.contrib import util, clang
import numpy as np import numpy as np
import ctypes import ctypes
...@@ -349,8 +350,8 @@ def test_rank_zero(): ...@@ -349,8 +350,8 @@ def test_rank_zero():
A = tvm.placeholder((n, ), name='A') A = tvm.placeholder((n, ), name='A')
scale = tvm.placeholder((), name='scale') scale = tvm.placeholder((), name='scale')
k = tvm.reduce_axis((0, n), name="k") k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C") C = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k), name="C")
D = tvm.compute((), lambda : C + 1) D = tvm.compute((), lambda : C() + 1)
s = tvm.create_schedule(D.op) s = tvm.create_schedule(D.op)
# build and invoke the kernel. # build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm") f = tvm.build(s, [A, scale, D], "llvm")
...@@ -373,8 +374,8 @@ def test_rank_zero_bound_checkers(): ...@@ -373,8 +374,8 @@ def test_rank_zero_bound_checkers():
A = tvm.placeholder((n, ), name='A') A = tvm.placeholder((n, ), name='A')
scale = tvm.placeholder((), name='scale') scale = tvm.placeholder((), name='scale')
k = tvm.reduce_axis((0, n), name="k") k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C") C = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k), name="C")
D = tvm.compute((), lambda : C + 1) D = tvm.compute((), lambda : C() + 1)
s = tvm.create_schedule(D.op) s = tvm.create_schedule(D.op)
# build and invoke the kernel. # build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm") f = tvm.build(s, [A, scale, D], "llvm")
......
...@@ -79,6 +79,8 @@ def _make_bop(broadcast_bop, orig_bop): ...@@ -79,6 +79,8 @@ def _make_bop(broadcast_bop, orig_bop):
tvm.Expr (otherwise) tvm.Expr (otherwise)
The result of {op} operation. The result of {op} operation.
""" """
print(lhs, type(lhs))
print(rhs, type(rhs))
if not isinstance(lhs, tvm.tensor.Tensor) and not isinstance(rhs, tvm.tensor.Tensor): if not isinstance(lhs, tvm.tensor.Tensor) and not isinstance(rhs, tvm.tensor.Tensor):
return orig_bop(lhs, rhs) return orig_bop(lhs, rhs)
return broadcast_bop(lhs, rhs) return broadcast_bop(lhs, rhs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment