Commit 401ffe13 by Steven S. Lyubomirsky Committed by Yizhi Liu

[Relay][Op] Add test for batch_flatten (#2134)

* Add tests for batch_flatten and softmax

* Softmax is already tested elsewhere
parent c113712d
......@@ -9,6 +9,7 @@ from ..op import OpPattern, schedule_injective
reg.register_schedule("nn.relu", schedule_injective)
reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
# softmax
@reg.register_schedule("nn.softmax")
def schedule_softmax(_, outputs, target):
"""Schedule definition of softmax"""
......
......@@ -391,6 +391,27 @@ def test_l2_normalize():
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def batch_flatten(data):
shape = data.shape
target_dim = 1
for i in range(len(shape) - 1):
target_dim = target_dim * shape[i + 1]
return np.reshape(data, (shape[0], target_dim))
def test_batch_flatten():
t1 = relay.TensorType((5, 10, 5))
x = relay.Var("x", t1)
func = relay.Function([x], relay.nn.batch_flatten(x))
data = np.random.rand(5, 10, 5).astype(t1.dtype)
ref_res = batch_flatten(data)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
if __name__ == "__main__":
test_pool2d()
test_avg_pool2d_no_count_pad()
......@@ -403,3 +424,4 @@ if __name__ == "__main__":
test_conv2d_transpose_infer_type()
test_conv2d_transpose_run()
test_conv2d_run()
test_batch_flatten()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment