Commit 5ea4072c by Tianqi Chen Committed by GitHub

[PYTHON] Allow general types (#425)

parent df3c996b
......@@ -65,10 +65,7 @@ class TVMType(ctypes.Structure):
head = ""
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = int(head) if head else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
......
......@@ -231,6 +231,7 @@ def test_multiple_func():
check_llvm()
def test_llvm_select():
def check_llvm(n, offset):
if not tvm.module.enabled("llvm"):
......@@ -251,7 +252,27 @@ def test_llvm_select():
check_llvm(64, 8)
def test_llvm_bool():
def check_llvm(n):
if not tvm.module.enabled("llvm"):
return
A = tvm.placeholder((n, ), name='A', dtype="int32")
C = tvm.compute((n,), lambda i: A[i].equal(1).astype("float"), name='C')
s = tvm.create_schedule(C.op)
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
c = tvm.nd.empty((n,), C.dtype, ctx)
f(a, c)
c_np = a.asnumpy() == 1
np.testing.assert_allclose(c.asnumpy(), c_np)
check_llvm(64)
if __name__ == "__main__":
test_llvm_bool()
test_llvm_persist_parallel()
test_llvm_select()
test_llvm_vadd_pipeline()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment