Commit 8eb4519a by MORITA Kazutaka Committed by Tianqi Chen

[TEST][KERAS] convert tvm output to channels_last format (#1733)

parent 27b6812b
......@@ -21,15 +21,6 @@ def verify_keras_frontend(keras_model, need_transpose=True):
for layer in keras_model._input_layers:
in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
#keras_model._output_coordinates contains the output_node, node_index and tensor_index
#get the outshapes from combining output node and tensor index
out_shapes = []
for layer, node_index, tensor_index in keras_model._output_coordinates:
layer_out = layer.output
if isinstance(layer.output, list):#if multiple outputs are there
layer_out = layer.output[tensor_index]
out_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer_out.shape))
def get_keras_output(xs, dtype='float32'):
return keras_model.predict(xs)
......@@ -44,20 +35,24 @@ def verify_keras_frontend(keras_model, need_transpose=True):
m.set_input(**params)
m.run()
out = [m.get_output(i).asnumpy()
for i, shape in enumerate(out_shapes)]
return out if len(out) > 1 else out[0]
return [m.get_output(i).asnumpy() for i in range(m.get_num_outputs())]
def to_channels_first(arr):
return arr.transpose([0, -1] + list(range(1, arr.ndim - 1)))
def to_channels_last(arr):
return arr.transpose([0] + list(range(2, arr.ndim)) + [1])
xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
keras_out = get_keras_output(xs)
keras_out = keras_out if isinstance(keras_out, list) else [keras_out]
for target, ctx in ctx_list():
tvm_out = get_tvm_output([x.transpose([0,3,1,2]) for x in xs ] if need_transpose else xs, target, ctx)
if isinstance (keras_out, list):
for kout, tout in zip(keras_out, tvm_out):
np.testing.assert_allclose(kout, tout.reshape(kout.shape), rtol=1e-5, atol=1e-5)
else:
np.testing.assert_allclose(keras_out, tvm_out.reshape(keras_out.shape), rtol=1e-5, atol=1e-5)
tvm_out = get_tvm_output([to_channels_first(x) for x in xs] if need_transpose else xs, target, ctx)
for kout, tout in zip(keras_out, tvm_out):
if need_transpose:
tout = to_channels_last(tout)
np.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5)
def test_forward_elemwise_add():
r = []
......@@ -111,7 +106,6 @@ def test_forward_conv():
keras.layers.SeparableConv2D(filters=10, kernel_size=(3,3), padding='same')]
for conv_func in conv_funcs:
x = conv_func(data)
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
......@@ -119,7 +113,6 @@ def test_forward_conv():
def test_forward_upsample():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.UpSampling2D(size=(3,3))(data)
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
......@@ -127,7 +120,6 @@ def test_forward_upsample():
def test_forward_reshape():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Reshape(target_shape=(32,32,3))(data)
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
......@@ -141,7 +133,6 @@ def test_forward_crop():
x = keras.layers.Cropping2D(cropping=(1, 0))(x)
x = keras.layers.Cropping2D(cropping=0)(x)
x = keras.layers.Add()([x, x])
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
......@@ -189,7 +180,6 @@ def test_forward_activations():
keras.layers.Activation('linear')]
for act_func in act_funcs:
x = act_func(data)
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment