Unverified Commit 9b274cbb by Tianqi Chen Committed by GitHub

[PYTHON] Make IntImm more like an integer (#5232)

parent 7de8a539
...@@ -439,6 +439,7 @@ class FloatImm(ConstExpr): ...@@ -439,6 +439,7 @@ class FloatImm(ConstExpr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
tvm.ir._ffi_api.FloatImm, dtype, value) tvm.ir._ffi_api.FloatImm, dtype, value)
@tvm._ffi.register_object @tvm._ffi.register_object
class IntImm(ConstExpr): class IntImm(ConstExpr):
"""Int constant. """Int constant.
...@@ -455,9 +456,24 @@ class IntImm(ConstExpr): ...@@ -455,9 +456,24 @@ class IntImm(ConstExpr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
tvm.ir._ffi_api.IntImm, dtype, value) tvm.ir._ffi_api.IntImm, dtype, value)
def __hash__(self):
return self.value
def __int__(self): def __int__(self):
return self.value return self.value
def __nonzero__(self):
return self.value != 0
def __eq__(self, other):
return _ffi_api._OpEQ(self, other)
def __ne__(self, other):
return _ffi_api._OpNE(self, other)
def __bool__(self):
return self.__nonzero__()
@tvm._ffi.register_object @tvm._ffi.register_object
class StringImm(ConstExpr): class StringImm(ConstExpr):
......
...@@ -302,7 +302,21 @@ def test_buffer_load_store(): ...@@ -302,7 +302,21 @@ def test_buffer_load_store():
assert isinstance(s, tvm.tir.BufferStore) assert isinstance(s, tvm.tir.BufferStore)
def test_intimm_cond():
x = tvm.runtime.convert(1)
y = tvm.runtime.convert(1)
s = {x}
assert y in s
assert x == y
assert x < 20
assert not (x >= 20)
assert x < 10 and y < 10
assert not tvm.runtime.convert(x != 1)
assert x == 1
if __name__ == "__main__": if __name__ == "__main__":
test_intimm_cond()
test_buffer_load_store() test_buffer_load_store()
test_vars() test_vars()
test_prim_func() test_prim_func()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment