Unverified Commit 15d67c11 by Tianqi Chen Committed by GitHub

[RELAY] reorg testcase, make checked_type property, fix constructor error handling (#1850)

parent 3e527669
...@@ -76,6 +76,8 @@ class NodeBase(object): ...@@ -76,6 +76,8 @@ class NodeBase(object):
So the return handle is directly set into the Node object So the return handle is directly set into the Node object
instead of creating a new Node. instead of creating a new Node.
""" """
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args) handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, NodeHandle): if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle) handle = NodeHandle(handle)
......
...@@ -82,6 +82,8 @@ cdef class NodeBase: ...@@ -82,6 +82,8 @@ cdef class NodeBase:
So the return handle is directly set into the Node object So the return handle is directly set into the Node object
instead of creating a new Node. instead of creating a new Node.
""" """
# avoid error raised during construction.
self.chandle = NULL
cdef void* chandle cdef void* chandle
ConstructorCall( ConstructorCall(
(<FunctionBase>fconstructor).chandle, (<FunctionBase>fconstructor).chandle,
......
...@@ -9,7 +9,15 @@ from .. import convert ...@@ -9,7 +9,15 @@ from .. import convert
class Expr(NodeBase): class Expr(NodeBase):
"""The base type for all Relay expressions.""" """The base type for all Relay expressions."""
@property
def checked_type(self): def checked_type(self):
"""Get the checked type of relay.
Returns
-------
checked_type : relay.Type
The checked type.
"""
ret = self._checked_type_ ret = self._checked_type_
if ret is None: if ret is None:
raise ValueError("The type checker has not populated" raise ValueError("The type checker has not populated"
......
...@@ -3,6 +3,13 @@ import tvm ...@@ -3,6 +3,13 @@ import tvm
from tvm import relay from tvm import relay
from tvm.expr import * from tvm.expr import *
def test_bad_constructor():
try:
x = relay.ty.TensorType("xx", "xx")
except tvm.TVMError:
pass
# Span # Span
def test_span(): def test_span():
span = relay.Span(None, 1, 1) span = relay.Span(None, 1, 1)
...@@ -169,6 +176,7 @@ def test_if(): ...@@ -169,6 +176,7 @@ def test_if():
if __name__ == "__main__": if __name__ == "__main__":
test_bad_constructor()
test_span() test_span()
test_tensor_type() test_tensor_type()
test_type_param() test_type_param()
......
...@@ -11,7 +11,7 @@ def test_expand_dims_infer_type(): ...@@ -11,7 +11,7 @@ def test_expand_dims_infer_type():
ib.ret(relay.expand_dims(x, axis=2)) ib.ret(relay.expand_dims(x, axis=2))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType( assert ftype.ret_type == relay.ty.TensorType(
(n, t, 1, 100), "float32") (n, t, 1, 100), "float32")
...@@ -27,7 +27,7 @@ def test_unary_op(): ...@@ -27,7 +27,7 @@ def test_unary_op():
ib.ret(op(x.var)) ib.ret(op(x.var))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((10, 4), "int32") assert ftype.ret_type == relay.TensorType((10, 4), "int32")
......
...@@ -16,7 +16,7 @@ def test_conv2d_infer_type(): ...@@ -16,7 +16,7 @@ def test_conv2d_infer_type():
channels=2)) channels=2))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType( assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 224, 224), "float32") (n, 2, 224, 224), "float32")
assert ftype.arg_types[1] == relay.ty.TensorType( assert ftype.arg_types[1] == relay.ty.TensorType(
...@@ -31,7 +31,7 @@ def test_conv2d_infer_type(): ...@@ -31,7 +31,7 @@ def test_conv2d_infer_type():
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32")) ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType( assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 222, 222), "int32") (n, 2, 222, 222), "int32")
...@@ -50,7 +50,7 @@ def test_conv2d_infer_type(): ...@@ -50,7 +50,7 @@ def test_conv2d_infer_type():
out_dtype="int32")) out_dtype="int32"))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType( assert ftype.ret_type == relay.ty.TensorType(
(1, 4, 224, 224, 4, 4), "int32") (1, 4, 224, 224, 4, 4), "int32")
assert ftype.arg_types[1] == relay.ty.TensorType( assert ftype.arg_types[1] == relay.ty.TensorType(
......
...@@ -9,5 +9,5 @@ def test_unary_identity(): ...@@ -9,5 +9,5 @@ def test_unary_identity():
ib.ret(op(x.var)) ib.ret(op(x.var))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32") assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32")
...@@ -16,7 +16,7 @@ def test_cmp_type(): ...@@ -16,7 +16,7 @@ def test_cmp_type():
ib.ret(op(x.var, y.var)) ib.ret(op(x.var, y.var))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1") assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
...@@ -32,7 +32,7 @@ def test_binary_broadcast(): ...@@ -32,7 +32,7 @@ def test_binary_broadcast():
ib.ret(op(x.var, y.var)) ib.ret(op(x.var, y.var))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type() ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32") assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
......
...@@ -12,7 +12,7 @@ from tvm.relay.expr import Function ...@@ -12,7 +12,7 @@ from tvm.relay.expr import Function
def assert_has_type(expr, typ, env=Environment({})): def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr) checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type() checked_type = checked_expr.checked_type
if checked_type != typ: if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % ( raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ)) checked_type, typ))
...@@ -20,7 +20,7 @@ def assert_has_type(expr, typ, env=Environment({})): ...@@ -20,7 +20,7 @@ def assert_has_type(expr, typ, env=Environment({})):
def assert_decl_has_type(env, name, typ): def assert_decl_has_type(env, name, typ):
func = env[name] func = env[name]
assert func.checked_type() == typ assert func.checked_type == typ
def test_monomorphic_let(): def test_monomorphic_let():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment