Commit 2112a1f9 by ziheng Committed by Tianqi Chen

[FIX] Miss kUInt in TypeCode2Str & dir method (#130)

* [FIX] Miss kUInt in TypeCode2Str & dir method

* [FIX] Add regression test
parent df44566c
...@@ -521,6 +521,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -521,6 +521,7 @@ class TVMRetValue : public TVMPODValue_ {
inline const char* TypeCode2Str(int type_code) { inline const char* TypeCode2Str(int type_code) {
switch (type_code) { switch (type_code) {
case kInt: return "int"; case kInt: return "int";
case kUInt: return "uint";
case kFloat: return "float"; case kFloat: return "float";
case kStr: return "str"; case kStr: return "str";
case kBytes: return "bytes"; case kBytes: return "bytes";
......
...@@ -6,7 +6,7 @@ import ctypes ...@@ -6,7 +6,7 @@ import ctypes
import sys import sys
from .. import _api_internal from .. import _api_internal
from .node_generic import NodeGeneric, convert_to_node, const from .node_generic import NodeGeneric, convert_to_node, const
from .base import _LIB, check_call, c_str, _FFI_MODE from .base import _LIB, check_call, c_str, py_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try: try:
......
...@@ -71,6 +71,15 @@ def test_stmt(): ...@@ -71,6 +71,15 @@ def test_stmt():
tvm.stmt.For.Serial, 0, tvm.stmt.For.Serial, 0,
x) x)
def test_dir():
x = tvm.var('x')
dir(x)
def test_dtype():
x = tvm.var('x')
assert x.dtype == 'int32'
y = tvm.var('y')
assert (x > y).dtype == 'uint1'
if __name__ == "__main__": if __name__ == "__main__":
test_attr() test_attr()
...@@ -81,3 +90,5 @@ if __name__ == "__main__": ...@@ -81,3 +90,5 @@ if __name__ == "__main__":
test_basic() test_basic()
test_stmt() test_stmt()
test_let() test_let()
test_dir()
test_dtype()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment