Unverified Commit 15d67c11 by Tianqi Chen Committed by GitHub

[RELAY] reorg testcase, make checked_type property, fix constructor error handling (#1850)

parent 3e527669
......@@ -76,6 +76,8 @@ class NodeBase(object):
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
......
......@@ -82,6 +82,8 @@ cdef class NodeBase:
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# avoid error raised during construction.
self.chandle = NULL
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
......
......@@ -9,7 +9,15 @@ from .. import convert
class Expr(NodeBase):
"""The base type for all Relay expressions."""
@property
def checked_type(self):
"""Get the checked type of relay.
Returns
-------
checked_type : relay.Type
The checked type.
"""
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
......
......@@ -3,6 +3,13 @@ import tvm
from tvm import relay
from tvm.expr import *
def test_bad_constructor():
try:
x = relay.ty.TensorType("xx", "xx")
except tvm.TVMError:
pass
# Span
def test_span():
span = relay.Span(None, 1, 1)
......@@ -169,6 +176,7 @@ def test_if():
if __name__ == "__main__":
test_bad_constructor()
test_span()
test_tensor_type()
test_type_param()
......
......@@ -11,7 +11,7 @@ def test_expand_dims_infer_type():
ib.ret(relay.expand_dims(x, axis=2))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, t, 1, 100), "float32")
......@@ -27,7 +27,7 @@ def test_unary_op():
ib.ret(op(x.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((10, 4), "int32")
......
......@@ -16,7 +16,7 @@ def test_conv2d_infer_type():
channels=2))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 224, 224), "float32")
assert ftype.arg_types[1] == relay.ty.TensorType(
......@@ -31,7 +31,7 @@ def test_conv2d_infer_type():
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 222, 222), "int32")
......@@ -50,7 +50,7 @@ def test_conv2d_infer_type():
out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, 4, 224, 224, 4, 4), "int32")
assert ftype.arg_types[1] == relay.ty.TensorType(
......
......@@ -9,5 +9,5 @@ def test_unary_identity():
ib.ret(op(x.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32")
......@@ -16,7 +16,7 @@ def test_cmp_type():
ib.ret(op(x.var, y.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
......@@ -32,7 +32,7 @@ def test_binary_broadcast():
ib.ret(op(x.var, y.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
......
......@@ -12,7 +12,7 @@ from tvm.relay.expr import Function
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type()
checked_type = checked_expr.checked_type
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
......@@ -20,7 +20,7 @@ def assert_has_type(expr, typ, env=Environment({})):
def assert_decl_has_type(env, name, typ):
func = env[name]
assert func.checked_type() == typ
assert func.checked_type == typ
def test_monomorphic_let():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment