Commit 3a1bb8c7 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Serialization round-trip tests (#1968)

parent 975d0d44
...@@ -141,3 +141,25 @@ def alpha_equal(lhs, rhs): ...@@ -141,3 +141,25 @@ def alpha_equal(lhs, rhs):
True iff lhs is alpha equal to rhs. True iff lhs is alpha equal to rhs.
""" """
return bool(_make._alpha_equal(lhs, rhs)) return bool(_make._alpha_equal(lhs, rhs))
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs: tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
...@@ -183,7 +183,7 @@ class AlphaEqualHandler: ...@@ -183,7 +183,7 @@ class AlphaEqualHandler:
bool VisitType_(const TypeRelationNode* lhs, const Type& other) final { bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) { if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
if (!lhs->func.same_as(rhs->func)) return false; if (lhs->func->name != rhs->func->name) return false;
if (lhs->num_inputs != rhs->num_inputs) return false; if (lhs->num_inputs != rhs->num_inputs) return false;
if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false; if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
if (lhs->args.size() != rhs->args.size()) return false; if (lhs->args.size() != rhs->args.size()) return false;
......
...@@ -2,6 +2,14 @@ ...@@ -2,6 +2,14 @@
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.expr import * from tvm.expr import *
from tvm.relay.ir_pass import graph_equal
def check_json_roundtrip(node):
json_str = tvm.save_json(node)
back = tvm.load_json(json_str)
assert graph_equal(back, node)
def test_bad_constructor(): def test_bad_constructor():
try: try:
...@@ -21,6 +29,13 @@ def test_span(): ...@@ -21,6 +29,13 @@ def test_span():
assert isinstance(span, relay.base.Span) assert isinstance(span, relay.base.Span)
str(span) str(span)
# span is not a node so we can't use graph_equal
# to test the round trip
back = tvm.load_json(tvm.save_json(span))
assert back.source == span.source
assert back.lineno == span.lineno
assert back.col_offset == span.col_offset
# Types # Types
def test_tensor_type(): def test_tensor_type():
...@@ -31,6 +46,7 @@ def test_tensor_type(): ...@@ -31,6 +46,7 @@ def test_tensor_type():
assert tt.shape == shape assert tt.shape == shape
assert tt.span == None assert tt.span == None
str(tt) str(tt)
check_json_roundtrip(tt)
def test_type_param(): def test_type_param():
...@@ -38,21 +54,23 @@ def test_type_param(): ...@@ -38,21 +54,23 @@ def test_type_param():
assert tp.kind == relay.Kind.Type assert tp.kind == relay.Kind.Type
# assert tp.span # TODO allow us to set span # assert tp.span # TODO allow us to set span
str(tp) str(tp)
check_json_roundtrip(tp)
def test_func_type(): def test_func_type():
type_params = tvm.convert([]) type_params = tvm.convert([])
type_constraints = tvm.convert([]) # TODO: fill me in type_constraints = tvm.convert([]) # TODO: fill me in
arg_types = tvm.convert([]) arg_types = tvm.convert([])
ret_type = None ret_type = relay.TensorType((1, 2, 3), 'float32')
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert tf.type_params == type_params assert tf.type_params == type_params
assert tf.type_constraints == type_constraints assert tf.type_constraints == type_constraints
assert tf.arg_types == arg_types assert tf.arg_types == arg_types
assert tf.ret_type == ret_type assert tf.ret_type == ret_type
assert tf.span == None assert tf.span == None
# TODO make sure we can set # TODO make sure we can set span
str(tf) str(tf)
check_json_roundtrip(tf)
def test_tuple_type(): def test_tuple_type():
...@@ -63,13 +81,15 @@ def test_tuple_type(): ...@@ -63,13 +81,15 @@ def test_tuple_type():
tup_ty = relay.TupleType(fields) tup_ty = relay.TupleType(fields)
assert tup_ty.fields == fields assert tup_ty.fields == fields
str(tup_ty)
check_json_roundtrip(tup_ty)
def test_type_relation(): def test_type_relation():
tp = relay.TypeVar('tp', relay.Kind.Type) tp = relay.TypeVar('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
args = tvm.convert([tf, tt, tp]) args = tvm.convert([tp, tf, tt])
num_inputs = 2 num_inputs = 2
func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
...@@ -78,6 +98,8 @@ def test_type_relation(): ...@@ -78,6 +98,8 @@ def test_type_relation():
tr = relay.TypeRelation(func, args, num_inputs, attrs) tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args assert tr.args == args
assert tr.num_inputs == num_inputs assert tr.num_inputs == num_inputs
str(tr)
check_json_roundtrip(tr)
def test_constant(): def test_constant():
...@@ -86,6 +108,7 @@ def test_constant(): ...@@ -86,6 +108,7 @@ def test_constant():
assert const.data == arr assert const.data == arr
assert const.span == None assert const.span == None
str(const) str(const)
check_json_roundtrip(const)
def test_tuple(): def test_tuple():
...@@ -94,6 +117,7 @@ def test_tuple(): ...@@ -94,6 +117,7 @@ def test_tuple():
assert tup.fields == fields assert tup.fields == fields
assert tup.span == None assert tup.span == None
str(tup) str(tup)
check_json_roundtrip(tup)
def test_local_var(): def test_local_var():
...@@ -103,6 +127,7 @@ def test_local_var(): ...@@ -103,6 +127,7 @@ def test_local_var():
assert lv.type_annotation is None assert lv.type_annotation is None
# assert lv.span == None todo(@jroesch): what do we do about spans # assert lv.span == None todo(@jroesch): what do we do about spans
str(lv) str(lv)
check_json_roundtrip(lv)
t1 = relay.ty.TensorType((), "float") t1 = relay.ty.TensorType((), "float")
lv = relay.Var(name_hint, t1) lv = relay.Var(name_hint, t1)
...@@ -116,20 +141,22 @@ def test_global_var(): ...@@ -116,20 +141,22 @@ def test_global_var():
gv.name_hint == name_hint gv.name_hint == name_hint
# assert lv.span == None todo(@jroesch): what do we do about spans # assert lv.span == None todo(@jroesch): what do we do about spans
str(gv) str(gv)
check_json_roundtrip(gv)
def test_function(): def test_function():
param_names = ['a', 'b', 'c', 'd'] param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names]) params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None ret_type = relay.TupleType(tvm.convert([]))
body = None body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([]) type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params) fn = relay.Function(params, body, ret_type, type_params)
assert fn.params == params assert fn.params == params
assert fn.body == body assert fn.body == body
assert fn.type_params == type_params assert fn.type_params == type_params
assert fn.span == None assert fn.span == None
str(fn) str(fn)
check_json_roundtrip(fn)
def test_call(): def test_call():
...@@ -141,6 +168,7 @@ def test_call(): ...@@ -141,6 +168,7 @@ def test_call():
assert call.args == args assert call.args == args
assert call.span == None assert call.span == None
str(call) str(call)
check_json_roundtrip(call)
def test_let(): def test_let():
...@@ -156,6 +184,7 @@ def test_let(): ...@@ -156,6 +184,7 @@ def test_let():
assert let.body == lv assert let.body == lv
assert let.span == None assert let.span == None
str(let) str(let)
check_json_roundtrip(let)
def test_if(): def test_if():
...@@ -168,6 +197,7 @@ def test_if(): ...@@ -168,6 +197,7 @@ def test_if():
assert ife.false_branch == right assert ife.false_branch == right
assert ife.span == None assert ife.span == None
str(ife) str(ife)
check_json_roundtrip(ife)
def test_tuple_get_item(): def test_tuple_get_item():
...@@ -176,6 +206,8 @@ def test_tuple_get_item(): ...@@ -176,6 +206,8 @@ def test_tuple_get_item():
assert get.tuple_value == tup assert get.tuple_value == tup
assert get.index == 1 assert get.index == 1
str(get) str(get)
check_json_roundtrip(get)
if __name__ == "__main__": if __name__ == "__main__":
test_bad_constructor() test_bad_constructor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment