Commit 331abacb by Siju Committed by Tianqi Chen

Testcases of onnx (#2274)

parent a8b34309
......@@ -346,9 +346,9 @@ class ThresholdedRelu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 0.0))
return _sym.relu(inputs[0] - alpha)
alpha = float(attr.get('alpha', 1.0))
alpha_tensor = _sym.full_like(inputs[0], fill_value=float(alpha))
return _sym.elemwise_mul(inputs[0], _sym.greater(inputs[0], alpha_tensor))
class ImageScaler(OnnxOpConverter):
......
......@@ -10,7 +10,7 @@ import onnx
from model_zoo import super_resolution, squeezenet1_1, lenet, resnet18_1_0
from onnx import helper, TensorProto
def get_tvm_output(graph_def, input_data, target, ctx, output_shape, output_dtype='float32'):
def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32'):
""" Generic function to execute and get tvm output"""
sym, params = nnvm.frontend.from_onnx(graph_def)
......@@ -47,12 +47,12 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape, output_dtyp
# get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, s in enumerate(output_shape):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
for i, _ in enumerate(output_shape):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
tvm_output = m.get_output(0)
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
......@@ -273,7 +273,7 @@ def test_slice():
_test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))
def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
indata = np.random.uniform(size=(2, 4, 5, 6)).astype(dtype)
indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
outdata = outfunc(indata, **npargs)
y = helper.make_node(opname, ['in'], ['out'], **kwargs)
......@@ -858,6 +858,154 @@ def test_split():
verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
[[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
def test_binary_ops():
in_shape = (1, 2, 3, 3)
dtype = "float32"
out_shape = in_shape
def verify_binary_ops(op, x, y, out_np, broadcast=None):
if broadcast is None:
z = helper.make_node(op, ['in1', 'in2'], ['out'])
else:
z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
graph = helper.make_graph([z],
'_test',
inputs = [helper.make_tensor_value_info("in1",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in2",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out)
x = np.random.uniform(size=in_shape).astype(dtype)
y = np.random.uniform(size=in_shape).astype(dtype)
z = np.random.uniform(size=(3,)).astype(dtype)
verify_binary_ops("Add",x, y, x + y, broadcast=None)
verify_binary_ops("Add", x, z, x + z, broadcast=True)
verify_binary_ops("Sub", x, y, x - y, broadcast=None)
verify_binary_ops("Sub", x, z, x - z, broadcast=True)
verify_binary_ops("Mul",x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, z, x * z, broadcast=True)
verify_binary_ops("Div", x, y, x / y, broadcast=None)
verify_binary_ops("Div", x, z, x / z, broadcast=True)
verify_binary_ops("Sum", x, y, x + y, broadcast=None)
def test_single_ops():
in_shape = (1, 2, 3, 3)
dtype = "float32"
out_shape = in_shape
def verify_single_ops(op, x, out_np):
z = helper.make_node(op, ['in1'], ['out'])
graph = helper.make_graph([z],
'_test',
inputs = [helper.make_tensor_value_info("in1",
TensorProto.FLOAT, list(in_shape)),],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out)
x = np.random.uniform(size=in_shape).astype(dtype)
verify_single_ops("Neg",x, -x)
verify_single_ops("Abs",x, np.abs(x))
verify_single_ops("Reciprocal",x, 1/x)
verify_single_ops("Sqrt",x, np.sqrt(x))
verify_single_ops("Relu",x, np.maximum(x, 0))
verify_single_ops("Exp",x, np.exp(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Tanh",x, np.tanh(x))
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
def test_leaky_relu():
def leaky_relu_x(x, alpha):
return np.where(x >= 0, x, x * alpha)
_test_onnx_op_elementwise((2, 4, 5, 6),
leaky_relu_x,
{'alpha': 0.25},
'float32',
'LeakyRelu',
{'alpha': 0.25})
def test_elu():
def elu_x(x, alpha):
return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
_test_onnx_op_elementwise((2, 4, 5, 6),
elu_x,
{'alpha': 0.25},
'float32',
'Elu',
{'alpha': 0.25})
def test_selu():
def selu_x(x, alpha, gamma):
return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
_test_onnx_op_elementwise((2, 4, 5, 6),
selu_x,
{'alpha': 0.25, 'gamma': 0.3},
'float32',
'Selu',
{'alpha': 0.25, 'gamma': 0.3})
def test_ThresholdedRelu():
def ThresholdedRelu_x(x, alpha):
out_np = np.clip(x, alpha, np.inf)
out_np[out_np == alpha] = 0
return out_np
_test_onnx_op_elementwise((2, 4, 5, 6),
ThresholdedRelu_x,
{'alpha': 0.25},
'float32',
'ThresholdedRelu',
{'alpha': 0.25})
def test_ScaledTanh():
def ScaledTanh_x(x, alpha, beta):
return alpha * np.tanh(beta * x)
_test_onnx_op_elementwise((2, 4, 5, 6),
ScaledTanh_x,
{'alpha': 0.25, 'beta': 0.3},
'float32',
'ScaledTanh',
{'alpha': 0.25, 'beta': 0.3})
def test_ParametricSoftplus():
def ParametricSoftplus_x(x, alpha, beta):
return alpha * np.log(np.exp(beta * x) + 1)
_test_onnx_op_elementwise((2, 4, 5, 6),
ParametricSoftplus_x,
{'alpha': 0.25, 'beta': 0.3},
'float32',
'ParametricSoftplus',
{'alpha': 0.25, 'beta': 0.3})
def test_Scale():
def Scale_x(x, scale):
return scale * x
_test_onnx_op_elementwise((2, 4, 5, 6),
Scale_x,
{'scale': 0.25},
'float32',
'Scale',
{'scale': 0.25})
def test_LogSoftmax():
_test_onnx_op_elementwise((1, 4),
topi.testing.log_softmax_python,
{},
'float32',
'LogSoftmax',
{'axis': 1})
if __name__ == '__main__':
# verify_super_resolution_example()
# verify_squeezenet1_1()
......@@ -889,3 +1037,13 @@ if __name__ == '__main__':
test_reduce_sum()
test_reduce_mean()
test_split()
test_binary_ops()
test_single_ops()
test_leaky_relu()
test_elu()
test_selu()
test_ThresholdedRelu()
test_ScaledTanh()
test_ParametricSoftplus()
test_Scale()
test_LogSoftmax()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment