Commit 3a1bb8c7 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Serialization round-trip tests (#1968)

parent 975d0d44
......@@ -141,3 +141,25 @@ def alpha_equal(lhs, rhs):
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs: tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
......@@ -183,7 +183,7 @@ class AlphaEqualHandler:
bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
if (!lhs->func.same_as(rhs->func)) return false;
if (lhs->func->name != rhs->func->name) return false;
if (lhs->num_inputs != rhs->num_inputs) return false;
if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
......
......@@ -2,6 +2,14 @@
import tvm
from tvm import relay
from tvm.expr import *
from tvm.relay.ir_pass import graph_equal
def check_json_roundtrip(node):
json_str = tvm.save_json(node)
back = tvm.load_json(json_str)
assert graph_equal(back, node)
def test_bad_constructor():
try:
......@@ -21,6 +29,13 @@ def test_span():
assert isinstance(span, relay.base.Span)
str(span)
# span is not a node so we can't use graph_equal
# to test the round trip
back = tvm.load_json(tvm.save_json(span))
assert back.source == span.source
assert back.lineno == span.lineno
assert back.col_offset == span.col_offset
# Types
def test_tensor_type():
......@@ -31,6 +46,7 @@ def test_tensor_type():
assert tt.shape == shape
assert tt.span == None
str(tt)
check_json_roundtrip(tt)
def test_type_param():
......@@ -38,21 +54,23 @@ def test_type_param():
assert tp.kind == relay.Kind.Type
# assert tp.span # TODO allow us to set span
str(tp)
check_json_roundtrip(tp)
def test_func_type():
type_params = tvm.convert([])
type_constraints = tvm.convert([]) # TODO: fill me in
arg_types = tvm.convert([])
ret_type = None
ret_type = relay.TensorType((1, 2, 3), 'float32')
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert tf.type_params == type_params
assert tf.type_constraints == type_constraints
assert tf.arg_types == arg_types
assert tf.ret_type == ret_type
assert tf.span == None
# TODO make sure we can set
# TODO make sure we can set span
str(tf)
check_json_roundtrip(tf)
def test_tuple_type():
......@@ -63,13 +81,15 @@ def test_tuple_type():
tup_ty = relay.TupleType(fields)
assert tup_ty.fields == fields
str(tup_ty)
check_json_roundtrip(tup_ty)
def test_type_relation():
tp = relay.TypeVar('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
args = tvm.convert([tf, tt, tp])
args = tvm.convert([tp, tf, tt])
num_inputs = 2
func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
......@@ -78,6 +98,8 @@ def test_type_relation():
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
str(tr)
check_json_roundtrip(tr)
def test_constant():
......@@ -86,6 +108,7 @@ def test_constant():
assert const.data == arr
assert const.span == None
str(const)
check_json_roundtrip(const)
def test_tuple():
......@@ -94,6 +117,7 @@ def test_tuple():
assert tup.fields == fields
assert tup.span == None
str(tup)
check_json_roundtrip(tup)
def test_local_var():
......@@ -103,6 +127,7 @@ def test_local_var():
assert lv.type_annotation is None
# assert lv.span == None todo(@jroesch): what do we do about spans
str(lv)
check_json_roundtrip(lv)
t1 = relay.ty.TensorType((), "float")
lv = relay.Var(name_hint, t1)
......@@ -116,20 +141,22 @@ def test_global_var():
gv.name_hint == name_hint
# assert lv.span == None todo(@jroesch): what do we do about spans
str(gv)
check_json_roundtrip(gv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = None
ret_type = relay.TupleType(tvm.convert([]))
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
fn = relay.Function(params, body, ret_type, type_params)
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
str(fn)
check_json_roundtrip(fn)
def test_call():
......@@ -141,6 +168,7 @@ def test_call():
assert call.args == args
assert call.span == None
str(call)
check_json_roundtrip(call)
def test_let():
......@@ -156,6 +184,7 @@ def test_let():
assert let.body == lv
assert let.span == None
str(let)
check_json_roundtrip(let)
def test_if():
......@@ -168,6 +197,7 @@ def test_if():
assert ife.false_branch == right
assert ife.span == None
str(ife)
check_json_roundtrip(ife)
def test_tuple_get_item():
......@@ -176,6 +206,8 @@ def test_tuple_get_item():
assert get.tuple_value == tup
assert get.index == 1
str(get)
check_json_roundtrip(get)
if __name__ == "__main__":
test_bad_constructor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment