Commit 156aa590 by Josh Fromm, committed by Jared Roesch

[Relay][Frontend][ONNX] New Operators and Opsets to Support BERT (#4197)

* Added slice v10

* Added constantofshape operation and small refactor.

* Finished one_hot implementation.

* Reshape working across all bert layers.

* Fixed constantofshape and removed code duplication.

* onnx model fully ingested.

* Working on improving onnx tests.

* Changed onnx testing to use onnxruntime instead of caffe2, also formatted.

* Add arbitrary output nodes to onnx frontend.

* Added v6 tiling for bert squad 8 support.

* Small syntax fixes

* Reduced code duplication in split opset versions.

* Added batch matmul test

* Added unstack split testing.

* Added onehot test, needs a little cleanup probably.

* Replaced deprecated constant fill with constantofshape and updated tests accordingly.

* Added tests for new opset version of slice and tile.

* lint clean up

* Lint fixes

* Changed onnx dependency

* Went back to caffe2 runtime for CI integration.

* Rebase and small typo/syntax changes.

* Added hard casting of onehot attributes to int.

parent 71f39be5
......@@ -19,11 +19,13 @@ from __future__ import absolute_import as _abs
import logging
import tvm
import numpy as np
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
from .. import transform as _transform
from .. import op as _op
from .. import analysis
class RequiredAttr(object):
......@@ -474,6 +476,50 @@ def infer_channels(inputs, transpose=False):
return channels
def infer_value(input_val, params):
"""A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that
whose output shape depends on the value of a tensor.
"""
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _expr.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
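For illustration, a minimal sketch of how infer_value can be used (the expression and parameter names below are hypothetical, not from this patch):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_value

# A shape subgraph whose only free variable is bound in params.
shape_var = relay.var("shape", shape=(2,), dtype="int32")
expr = relay.add(shape_var, relay.const(np.array([1, 1], dtype="int32")))
params = {"shape": tvm.nd.array(np.array([3, 4], dtype="int32"))}
value = infer_value(expr, params)  # builds at opt_level=0 and runs on LLVM
print(value.asnumpy())             # [4 5]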
def infer_value_simulated(input_val, params):
"""Extention to infer_value that can be used when some input
values are missing. This function creates dummy inputs with the same
shape and random values then calls infer_value. This is helpful when
implementing certain onnx operators where we need to evaluate the graph
to determine a static shape.
"""
fake_params = []
# Add a fake copy of all missing params.
for free_param in analysis.free_vars(input_val):
if free_param.name_hint not in params:
fp_dtype = free_param.type_annotation.dtype
fp_shape = [s.value for s in free_param.type_annotation.shape]
fake_params.append(free_param)
params[free_param.name_hint] = tvm.nd.array(
np.random.rand(*fp_shape).astype(fp_dtype)
)
# Now infer the value.
output_value = infer_value(input_val, params)
# Clean fake params out of param dictionary.
for fake_p in fake_params:
params.pop(fake_p.name_hint, None)
return output_value
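A companion sketch for infer_value_simulated, assuming a free variable that is absent from params and therefore gets a random stand-in of the annotated shape and dtype:

import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_value_simulated

data = relay.var("data", shape=(1, 4), dtype="float32")
shape_expr = relay.shape_of(data, dtype="int32")
# "data" is not in params, so a fake random tensor of shape (1, 4) is
# injected, the graph is run, and the fake entry is removed again.
value = infer_value_simulated(shape_expr, {})
print(value.asnumpy())  # [1 4]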
def new_var(name_hint,
type_annotation=None,
shape=None,
......
......@@ -18,20 +18,30 @@
"""ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs
import logging
import numpy as np
import tvm
from ... import nd as _nd
from .. import analysis
from .. import transform as _transform
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name
__all__ = ['from_onnx']
def get_numpy(tensor_proto):
"""Grab data in TensorProto and convert to numpy array."""
try:
from onnx.numpy_helper import to_array
except ImportError as e:
raise ImportError(
"Unable to import onnx which is required {}".format(e))
return to_array(tensor_proto)
def dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
......@@ -43,6 +53,7 @@ def dimension_picker(prefix, surfix=''):
return _impl
def revert_caffe2_pad(pads):
"""Caffe2 requires two times the normal padding."""
if len(pads) == 4:
......@@ -279,6 +290,21 @@ class MatMul(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op takes 2 inputs, {} given".format(len(inputs))
# Need to check input shape as batch matmul must be supported.
a_shape = infer_shape(inputs[0])
# When performing a batch matmul, we need to properly handle N-dim shapes.
if len(a_shape) > 2:
b_shape = infer_shape(inputs[1])
# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
# Reshape output to original dimensions.
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
# Otherwise a simple dense op will get the job done.
input_1_t = _op.transpose(inputs[1], axes=(1, 0))
return _op.nn.dense(inputs[0], input_1_t)
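A numpy trace of the batch case above, assuming sample shapes (2, 3, 4, 3) x (2, 3, 3, 4); the converter transposes b only because nn.batch_matmul computes x @ y^T over the last two axes:

import numpy as np

a = np.random.rand(2, 3, 4, 3).astype("float32")
b = np.random.rand(2, 3, 3, 4).astype("float32")
a3 = a.reshape(-1, 4, 3)  # (6, 4, 3)
b3 = b.reshape(-1, 3, 4)  # (6, 3, 4)
# batch_matmul(a3, transpose(b3)) in Relay equals plain a3 @ b3 here.
out = np.matmul(a3, b3).reshape(2, 3, 4, 4)
np.testing.assert_allclose(out, np.matmul(a, b), rtol=1e-5)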
......@@ -426,35 +452,18 @@ class Reshape(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'shape' in attr:
return _op.reshape(inputs[0], attr['shape'])
@classmethod
def _impl_v5(cls, inputs, attr, params):
if get_name(inputs[1]) in params:
shape = tuple(params[inputs[1].name_hint].asnumpy())
out = _op.reshape(inputs[0], shape)
else:
data, shape = inputs
logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
shape_params = analysis.free_vars(shape)
func = _expr.Function(shape_params, shape)
mod = _module.Module.from_expr(func)
seq = _transform.Sequential([_transform.InferType(),
_transform.FoldConstant(),
_transform.FuseOps(0),
_transform.InferType()])
with tvm.relay.PassContext(opt_level=2):
mod = seq(mod)
with tvm.relay.build_config(opt_level=0):
ex = tvm.relay.create_executor("debug", mod=mod)
inputs = []
for sp in shape_params:
if not sp.name_hint in params:
sh = [int(i) for i in sp.type_annotation.shape]
inputs.append(
tvm.nd.array(np.random.rand(*sh).astype('float32')))
static_shape = ex.evaluate()(*inputs, **params)
out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
static_shape = infer_value_simulated(shape, params)
out = _op.reshape(data, newshape=tuple(
static_shape.asnumpy().astype('int32')))
return out
class Concat(OnnxOpConverter):
......@@ -640,11 +649,17 @@ class Split(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
splits = attr.get('split', False)
if splits:
attr['indices_or_sections'] = []
index = 0
for i in attr['split'][:-1]:
for i in splits[:-1]:
index += i
attr['indices_or_sections'].append(index)
# When splits isn't specified, divide evenly over the axis.
else:
in_shape = infer_shape(inputs[0])
attr['indices_or_sections'] = in_shape[attr['axis']]
return AttrCvt(
'split',
ignores=['split'])(inputs, attr, params)
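The running sum above turns ONNX per-output sizes into Relay cut points; a small trace with hypothetical sizes:

splits = [2, 1, 3]   # ONNX 'split' attribute: sizes of each output
indices = []
index = 0
for s in splits[:-1]:
    index += s
    indices.append(index)
print(indices)       # [2, 3] -> indices_or_sections for relay split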
......@@ -653,32 +668,38 @@ class Split(OnnxOpConverter):
class Slice(OnnxOpConverter):
""" Operator converter for Slice.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if isinstance(attr['starts'], int):
attr['starts'] = (attr['starts'],)
attr['ends'] = (attr['ends'],)
try:
# Update the starts and ends according to axes if required.
if isinstance(attr['axes'], int):
attr['axes'] = (attr['axes'],)
if (max(attr['axes']) + 1) != len(attr['axes']):
@classmethod
def _common(cls, starts, ends, axes):
new_axes = []
new_starts = []
new_ends = []
pop_index = 0
for i in range(max(attr['axes']) + 1):
if i in attr['axes']:
for i in range(max(axes) + 1):
if i in axes:
new_axes.append(i)
new_starts.append(attr['starts'][pop_index])
new_ends.append(attr['ends'][pop_index])
new_starts.append(starts[pop_index])
new_ends.append(ends[pop_index])
pop_index += 1
else:
new_axes.append(i)
new_starts.append(0)
new_ends.append(np.iinfo(np.int32).max)
return new_starts, new_ends, new_axes
@classmethod
def _impl_v1(cls, inputs, attr, params):
if isinstance(attr['starts'], int):
attr['starts'] = (attr['starts'],)
attr['ends'] = (attr['ends'],)
try:
# Update the starts and ends according to axes if required.
if isinstance(attr['axes'], int):
attr['axes'] = (attr['axes'],)
if (max(attr['axes']) + 1) != len(attr['axes']):
new_starts, new_ends, new_axes = cls._common(
attr['starts'], attr['ends'], attr['axes'])
attr['axes'] = new_axes
attr['starts'] = new_starts
attr['ends'] = new_ends
......@@ -690,6 +711,23 @@ class Slice(OnnxOpConverter):
'ends': 'end'},
ignores=['axes'])(inputs, attr)
@classmethod
def _impl_v10(cls, inputs, attr, params):
starts = params[get_name(inputs[1])].asnumpy()
ends = params[get_name(inputs[2])].asnumpy()
# Update the starts and ends according to axes if required.
if len(inputs) >= 4:
axes = params[get_name(inputs[3])].asnumpy()
if max(axes + 1) != len(axes):
new_starts, new_ends, _ = cls._common(
starts, ends, axes)
starts = new_starts
ends = new_ends
return _op.strided_slice(inputs[0], begin=starts, end=ends)
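_common densifies sparse slice attributes: every axis up to max(axes) gets an entry, with unspecified axes sliced over their full range. A standalone trace of that logic (values hypothetical):

import numpy as np

starts, ends, axes = [1], [4], [2]   # ONNX slices only axis 2
new_starts, new_ends = [], []
pop_index = 0
for i in range(max(axes) + 1):
    if i in axes:
        new_starts.append(starts[pop_index])
        new_ends.append(ends[pop_index])
        pop_index += 1
    else:
        new_starts.append(0)
        new_ends.append(np.iinfo(np.int32).max)
print(new_starts)  # [0, 0, 1]
print(new_ends)    # [2147483647, 2147483647, 4]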
class Gather(OnnxOpConverter):
""" Operator converter for Gather.
"""
......@@ -698,7 +736,6 @@ class Gather(OnnxOpConverter):
axis = attr.get('axis', 0)
return AttrCvt('take',
extras={'axis':axis})(inputs, {})
#return _op.take(inputs[0], inputs[1], axis)
class Greater(OnnxOpConverter):
......@@ -848,33 +885,49 @@ class Softmax(OnnxOpConverter):
attr['axis'] = 1
return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params)
class ConstantFill(OnnxOpConverter):
""" Operator converter for ConstantFill.
class OneHot(OnnxOpConverter):
""" Operator converter for OneHot.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
num_inputs = len(inputs)
if 'shape' in attr:
if num_inputs > 1:
raise ImportError(
"Can't set shape and input tensor at a time")
shape = attr.pop('shape')
else:
if num_inputs == 1:
raise ImportError(
"Either shape attribute or input should be set")
if 'input_as_shape' in attr and attr['input_as_shape']:
shape = params[get_name(inputs[0])].asnumpy()
def _impl_v9(cls, inputs, attr, params):
# Extract relay one_hot inputs.
indices, depth, values = inputs
# Split onnx on off values into two separate expressions.
off_value, on_value = _op.take(
values, _op.const(0)), _op.take(values, _op.const(1))
# Extract the datatype of the output from on_value.
dtype = infer_type(on_value).checked_type.dtype
# Convert depth into an integer.
depth = int(infer_value(depth, params).asnumpy()[0])
# set default value when axis is not set in the model
if 'axis' not in attr:
attr['axis'] = -1
return _op.one_hot(indices,
on_value,
off_value,
depth,
int(attr['axis']),
dtype=dtype)
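The on/off split mirrors the ONNX spec, where values packs [off_value, on_value]; a numpy equivalent of the conversion (illustrative):

import numpy as np

indices = np.array([0, 2, 1])
depth = 3
values = np.array([0, 1])            # ONNX packs [off_value, on_value]
off_value, on_value = values[0], values[1]
out = np.full((indices.size, depth), off_value)
out[np.arange(indices.size), indices] = on_value
print(out)   # rows are one-hot along the default axis -1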
class ConstantOfShape(OnnxOpConverter):
""" Operator converter for ConstantOfShape.
"""
@classmethod
def _impl_v9(cls, inputs, attr, params):
if 'value' in attr:
np_value = get_numpy(attr.pop('value'))[0]
value = _expr.const(np_value)
dtype = np_value.dtype.name
else:
if 'extra_shape' in attr:
raise tvm.error.OpAttributeInvalid('Attribute "extra_shape" not '
'supported with "fill_like" for '
'operator ConstantFill.')
return _op.full_like(inputs[0], inputs[1])
value = _expr.const(0)
dtype = 'float32'
static_shape = infer_value_simulated(inputs[0], params)
output = _op.full(
value, shape=tuple(static_shape.asnumpy().astype('int32')), dtype=dtype)
return output
if 'extra_shape' in attr:
shape = shape + attr.pop('extra_shape')
return _op.full(inputs[0], shape)
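ConstantOfShape takes its output shape from a tensor input, which is why the converter constant-folds inputs[0] with infer_value_simulated before calling full; in numpy terms (illustrative):

import numpy as np

shape_tensor = np.array([2, 3], dtype="int32")  # the folded inputs[0]
value, dtype = np.float32(10), "float32"        # from the 'value' attr
out = np.full(tuple(shape_tensor), value, dtype=dtype)
print(out.shape, out.dtype)  # (2, 3) float32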
class Sign(OnnxOpConverter):
""" Operator converter for Sign.
......@@ -916,6 +969,12 @@ class Tile(Elemwise):
reps = attr.pop('repeats') # The number of times repeating the tensor data.
return _op.tile(inputs[0], reps)
@classmethod
def _impl_v6(cls, inputs, attr, params):
reps = tuple(infer_value_simulated(
inputs[1], params).asnumpy().astype('int32'))
return _op.tile(inputs[0], reps)
class Erf(OnnxOpConverter):
"""Operator converter for Erf
"""
......@@ -948,7 +1007,7 @@ def _get_convert_map(opset):
'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
'ScaledTanh': ScaledTanh.get_converter(opset),
'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
'ConstantFill': ConstantFill.get_converter(opset),
'ConstantOfShape': ConstantOfShape.get_converter(opset),
# 'GivenTensorFill'
'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
'Scale': Scale.get_converter(opset),
......@@ -958,7 +1017,7 @@ def _get_convert_map(opset):
# 'MeanVarianceNormalization'
# 'Crop'
# 'Embedding'
'Upsample' : Upsample.get_converter(opset),
'Upsample': Upsample.get_converter(opset),
'SpatialBN': BatchNorm.get_converter(opset),
# defs/generator
......@@ -1002,6 +1061,7 @@ def _get_convert_map(opset):
# softmax default axis is different in onnx
'Softmax': Softmax.get_converter(opset),
'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
'OneHot': OneHot.get_converter(opset),
# 'Hardmax'
'Softsign': Softsign.get_converter(opset),
'SoftPlus': SoftPlus.get_converter(opset),
......@@ -1164,14 +1224,6 @@ class GraphProto(object):
shape=list(t_proto.dims),
dtype=array.dtype)
else:
if op_name == "ConstantFill":
fill_value = attr.get('value', 0.0)
dtype = attr.get('dtype', b'int32').decode("utf-8")
i_name = node.output[0]
self._params[i_name] = fill_value
self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
inputs.append(self._nodes[i_name])
i_name = self._parse_value_proto(node)
attr['tvm_custom'] = {}
attr['tvm_custom']['name'] = i_name
......@@ -1214,13 +1266,7 @@ class GraphProto(object):
return dtype
def _parse_array(self, tensor_proto):
"""Grab data in TensorProto and convert to numpy array."""
try:
from onnx.numpy_helper import to_array
except ImportError as e:
raise ImportError(
"Unable to import onnx which is required {}".format(e))
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))
return _nd.array(np_array)
def _parse_attr(self, attr_proto):
......@@ -1301,7 +1347,8 @@ class GraphProto(object):
def from_onnx(model,
shape=None,
dtype="float32"):
dtype="float32",
opset=None):
"""Convert a ONNX model into an equivalent Relay Function.
ONNX graphs are represented as Python Protobuf objects.
......@@ -1322,6 +1369,10 @@ def from_onnx(model,
dtype : str or dict of str to str
The input types to the graph
opset : int, optional
Override the autodetected opset.
This can be helpful for some testing.
Returns
-------
mod : tvm.relay.Module
......@@ -1344,6 +1395,7 @@ def from_onnx(model,
pass
g = GraphProto(shape, dtype)
graph = model.graph
if opset is None:
try:
opset = model.opset_import[0].version if model.opset_import else 1
except AttributeError:
......
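A hedged usage sketch of the new opset override (file name and input name hypothetical):

import onnx
from tvm import relay

onnx_model = onnx.load("bert.onnx")                  # hypothetical model
shape_dict = {"input_ids": (1, 128)}                 # hypothetical input
# Force the opset-10 converters regardless of what the model declares.
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, opset=10)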
......@@ -39,22 +39,10 @@ from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_channels as _infer_channels
from .common import infer_value as _infer_value
__all__ = ['from_tensorflow']
def _infer_value(input_val, params):
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _expr.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
def _get_pad_pair(input1d, kernel1d, stride1d):
if input1d % stride1d == 0:
......
......@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import attr
import numpy as np
import math
import torch
......@@ -26,11 +25,11 @@ from tvm import relay
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
import onnx
from onnx import helper, TensorProto
import unittest
from onnx import helper, TensorProto, mapping
import scipy
def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32'):
def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
""" Generic function to execute and get tvm output"""
target = 'llvm'
if isinstance(input_data, list):
......@@ -46,21 +45,22 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype}
mod, params = relay.frontend.from_onnx(graph_def, shape_dict)
mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
with relay.build_config(opt_level=1):
graph, lib, params = relay.build(mod,
target,
params=params)
ctx = tvm.cpu(0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
if isinstance(input_data, list):
for i, e in enumerate(input_names):
m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
m.set_input(input_names[i], tvm.nd.array(
input_data[i].astype(input_data[i].dtype)))
else:
m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(input_names, tvm.nd.array(
input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
......@@ -76,6 +76,7 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
tvm_output = m.get_output(0)
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
import caffe2.python.onnx.backend
prepared_backend = caffe2.python.onnx.backend.prepare(model)
......@@ -93,15 +94,20 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_super_resolution_example():
verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
verify_onnx_forward_impl(
super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
def verify_squeezenet1_1():
verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))
def verify_lenet():
verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
def verify_resnet18():
verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
......@@ -114,17 +120,17 @@ def test_reshape():
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = onnx.TensorProto.INT32,
dims = ref_array.shape,
vals = ref_array.flatten().astype(int)))
value=onnx.helper.make_tensor(name='const_tensor',
data_type=onnx.TensorProto.INT32,
dims=ref_array.shape,
vals=ref_array.flatten().astype(int)))
reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
graph = helper.make_graph([ref_node, reshape_node],
"reshape_test",
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='reshape_test')
......@@ -135,6 +141,7 @@ def test_reshape():
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
def test_shape():
in_shape = (4, 3, 3, 4)
ref_shape = (6, 2, 4, 3)
......@@ -143,19 +150,19 @@ def test_shape():
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = onnx.TensorProto.INT32,
dims = ref_array.shape,
vals = ref_array.flatten().astype(int)))
value=onnx.helper.make_tensor(name='const_tensor',
data_type=onnx.TensorProto.INT32,
dims=ref_array.shape,
vals=ref_array.flatten().astype(int)))
reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
shape_node = helper.make_node("Shape", ['out'], ['final_out'])
graph = helper.make_graph([ref_node, reshape_node, shape_node],
"shape_test",
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("final_out",
outputs=[helper.make_tensor_value_info("final_out",
TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='shape_test')
......@@ -166,6 +173,7 @@ def test_shape():
tvm.testing.assert_allclose(ref_shape, tvm_out)
def _test_power_iteration(x_shape, y_shape):
if isinstance(y_shape, int):
y_shape = [y_shape]
......@@ -179,11 +187,11 @@ def _test_power_iteration(x_shape, y_shape):
graph = helper.make_graph([res],
'power_test',
inputs = [helper.make_tensor_value_info("x",
inputs=[helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(x_shape)),
helper.make_tensor_value_info("y",
TensorProto.FLOAT, list(y_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(np_res.shape))])
model = helper.make_model(graph, producer_name='power_test')
......@@ -192,11 +200,13 @@ def _test_power_iteration(x_shape, y_shape):
tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape)
tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5)
def test_power():
_test_power_iteration((1, 3), (1))
_test_power_iteration((2, 3), (2, 3))
_test_power_iteration((2, 3), (1, 3))
def test_squeeze():
in_shape = (1, 3, 1, 3, 1, 1)
out_shape = (3, 3)
......@@ -204,9 +214,9 @@ def test_squeeze():
graph = helper.make_graph([y],
'squeeze_test',
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='squeeze_test')
......@@ -217,19 +227,20 @@ def test_squeeze():
tvm.testing.assert_allclose(out_shape, tvm_out.shape)
def test_flatten():
in_shape = (1, 3, 4, 4)
axis = 1
ref_shape = (1, 48)
flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis = axis)
flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis)
graph = helper.make_graph([flatten_node],
"flatten_test",
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='flatten_test')
......@@ -240,6 +251,7 @@ def test_flatten():
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
def test_unsqueeze():
in_shape = (3, 3)
axis = (0, 3, 4)
......@@ -248,9 +260,9 @@ def test_unsqueeze():
graph = helper.make_graph([y],
'squeeze_test',
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='squeeze_test')
......@@ -261,6 +273,7 @@ def test_unsqueeze():
tvm.testing.assert_allclose(out_shape, tvm_out.shape)
def verify_gather(in_shape, indices, axis, dtype):
x = np.random.uniform(size=in_shape).astype(dtype)
indices = np.array(indices, dtype="int32")
......@@ -270,52 +283,123 @@ def verify_gather(in_shape, indices, axis, dtype):
graph = helper.make_graph([y],
'gather_test',
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("indices",
TensorProto.INT32, list(indices.shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='gather_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape)
tvm_out = get_tvm_output(
model, [x, indices], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out)
def test_gather():
verify_gather((4,), [1], 0, 'int32')
verify_gather((1,4), [0], 0, 'int32')
verify_gather((4,), [[[1,0],[0,1]]], 0, 'float32')
verify_gather((2,2), [[[1,0],[0,1]]], 1, 'int32')
verify_gather((3,3,3), [[[1,0]]], -1, 'int32')
verify_gather((4,3,5,6), [[2,1,0,0]], 0, 'float32')
verify_gather((1, 4), [0], 0, 'int32')
verify_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
verify_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
verify_gather((3, 3, 3), [[[1, 0]]], -1, 'int32')
verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
if axes:
y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
y = helper.make_node(
"Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
else:
y = helper.make_node("Slice", ['in'], ['out'], starts=starts, ends=ends)
y = helper.make_node(
"Slice", ['in'], ['out'], starts=starts, ends=ends)
graph = helper.make_graph([y],
'slice_test',
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='slice_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm_out = get_tvm_output(
model, indata, target, ctx, outdata.shape, 'float32', opset=1)
tvm.testing.assert_allclose(outdata, tvm_out)
def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):
if isinstance(starts, int):
starts = (starts, )
if isinstance(ends, int):
ends = (ends, )
if isinstance(axes, int):
axes = (axes, )
starts = np.asarray(starts)
ends = np.asarray(ends)
inputs = [
helper.make_tensor_value_info("data", TensorProto.FLOAT,
list(indata.shape)),
helper.make_tensor_value_info("starts", TensorProto.INT32,
list(starts.shape)),
helper.make_tensor_value_info("ends", TensorProto.INT32,
list(ends.shape))
]
initializer = [
helper.make_tensor("starts", TensorProto.INT32, list(starts.shape),
starts),
helper.make_tensor("ends", TensorProto.INT32, list(ends.shape), ends)
]
if axes:
axes = np.asarray(axes)
y = helper.make_node("Slice", ["data", "starts", "ends", "axes"],
["out"])
inputs.append(
helper.make_tensor_value_info("axes", TensorProto.INT32,
list(axes.shape)))
initializer.append(
helper.make_tensor("axes", TensorProto.INT32, list(axes.shape),
axes))
else:
y = helper.make_node("Slice", ["data", "starts", "ends"], ["out"])
graph = helper.make_graph([y],
'slice_test',
inputs=inputs,
outputs=[
helper.make_tensor_value_info(
"out", TensorProto.FLOAT,
list(outdata.shape))
],
initializer=initializer)
model = helper.make_model(graph, producer_name='slice_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model,
indata,
target,
ctx,
outdata.shape,
'float32',
opset=10)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_slice():
x = np.random.randn(20, 10, 5).astype(np.float32)
_test_slice_iteration(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))
_test_slice_iteration_v1(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration_v1(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration_v1(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration_v1(x, x[:, 0:-1], (0), (-1), (1))
_test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration_v10(x, x[:, 0:-1], (0), (-1), (1))
def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
......@@ -325,24 +409,29 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
graph = helper.make_graph([y],
opname+'_test',
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name=opname+'_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
tvm_out = get_tvm_output(
model, indata, target, ctx, outdata.shape, dtype)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_floor():
_test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {})
_test_onnx_op_elementwise((2, 4, 5, 6), np.floor,
{}, 'float32', 'Floor', {})
def test_ceil():
_test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {})
def test_clip():
_test_onnx_op_elementwise((2, 4, 5, 6),
np.clip,
......@@ -351,6 +440,38 @@ def test_clip():
'Clip',
{'min': -1.0, 'max': 1.0})
def test_onehot():
indices_shape = [10]
indices_array = np.random.randint(
low=0, high=9, size=indices_shape, dtype='int32')
depth = 10
values = np.asarray([0, 1])
out_np = np.eye(depth)[indices_array.reshape(-1)]
onehot_node = helper.make_node(
"OneHot", ["indices", "depth", "values"], ["out"])
graph = helper.make_graph([onehot_node],
"onehot_test",
inputs=[helper.make_tensor_value_info("indices",
TensorProto.INT32, indices_shape),
helper.make_tensor_value_info("depth",
TensorProto.INT32, [1]),
helper.make_tensor_value_info("values",
TensorProto.INT32, values.shape)],
initializer=[helper.make_tensor("depth", TensorProto.INT32, [1], [depth]),
helper.make_tensor("values", TensorProto.INT32, values.shape, values)],
outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)])
model = helper.make_model(graph, producer_name="onehot_test")
for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, [indices_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_matmul():
a_shape = (4, 3)
b_shape = (3, 4)
......@@ -363,42 +484,73 @@ def test_matmul():
graph = helper.make_graph([mul_node],
"matmul_test",
inputs = [helper.make_tensor_value_info("a",
inputs=[helper.make_tensor_value_info("a",
TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b",
TensorProto.FLOAT, list(b_shape))],
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='matmul_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_batch_matmul():
a_shape = (2, 3, 4, 3)
b_shape = (2, 3, 3, 4)
a_array = np.random.uniform(size=a_shape).astype('float32')
b_array = np.random.uniform(size=b_shape).astype('float32')
out_np = np.matmul(a_array, b_array)
mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
graph = helper.make_graph([mul_node],
"matmul_test",
inputs=[helper.make_tensor_value_info("a",
TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b",
TensorProto.FLOAT, list(b_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='matmul_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape)
tvm_out = get_tvm_output(
model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype)
if alpha == None and beta == None and bias==None:
if alpha == None and beta == None and bias == None:
alpha = 0.0001
beta = 0.75
bias = 1.0
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize)
node = onnx.helper.make_node(
'LRN', inputs=['in'], outputs=['out'], size=nsize)
else:
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha,
beta=beta, bias=bias, size=nsize)
graph = helper.make_graph([node],
"lrn_test",
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
inputs=[helper.make_tensor_value_info(
"in", TensorProto.FLOAT, list(shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
model = helper.make_model(graph, producer_name='lrn_test')
def _get_python_lrn():
square_sum = np.zeros(shape).astype(dtype)
for n, c, h, w in np.ndindex(in_array.shape):
square_sum[n, c, h, w] = sum(in_array[n,
max(0, c - int(math.floor((nsize - 1) / 2))): \
max(0, c - int(math.floor((nsize - 1) / 2))):
min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
h,
w] ** 2)
......@@ -408,7 +560,8 @@ def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
for target, ctx in ctx_list():
input_name = model.graph.input[0].name
py_out = _get_python_lrn()
tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, 'float32')
tvm_out = get_tvm_output(
model, in_array, target, ctx, py_out.shape, 'float32')
tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5)
......@@ -444,12 +597,14 @@ def verify_instance_norm(shape, axis=1):
graph = helper.make_graph([node],
"instance_norm_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
helper.make_tensor_value_info(
"gamma", TensorProto.FLOAT, (shape[1],)),
helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
model = helper.make_model(graph, producer_name='instance_norm_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
tvm_out = get_tvm_output(
model, [x, gamma, beta], target, ctx, shape, 'float32')
tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
......@@ -464,103 +619,122 @@ def _test_upsample_nearest():
scale = 2
in_shape = (1, 1, 3, 3)
out_shape = (1, 1, 3*scale, 3*scale)
y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
y = helper.make_node("Upsample", ['in'], [
'out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW")
out_array = topi.testing.upsampling_python(
in_array, (scale, scale), "NCHW")
graph = helper.make_graph([y],
'upsample_nearest_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
inputs=[helper.make_tensor_value_info(
"in", TensorProto.FLOAT, list(in_shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='upsample_nearest_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
tvm_out = get_tvm_output(
model, in_array, target, ctx, out_shape, 'float32')
tvm.testing.assert_allclose(out_array, tvm_out)
def _test_upsample_bilinear():
scale = 2
in_shape = (1, 1, 3, 3)
out_shape = (1, 1, 3*scale, 3*scale)
y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
y = helper.make_node("Upsample", ['in'], [
'out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
out_array = topi.testing.bilinear_resize_python(
in_array, (3*scale, 3*scale), "NCHW")
graph = helper.make_graph([y],
'upsample_bilinear_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
inputs=[helper.make_tensor_value_info(
"in", TensorProto.FLOAT, list(in_shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='upsample_bilinear_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
tvm_out = get_tvm_output(
model, in_array, target, ctx, out_shape, 'float32')
tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
def _test_upsample_bilinear_opset9():
scale = 2
in_shape = (1, 1, 3, 3)
out_shape = (1, 1, 3*scale, 3*scale)
y = helper.make_node("Upsample", ['in','scales'], ['out'], mode='linear')
scales=[1.0, 1.0, 2.0, 2.0]
y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear')
scales = [1.0, 1.0, 2.0, 2.0]
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
out_array = topi.testing.bilinear_resize_python(
in_array, (3*scale, 3*scale), "NCHW")
ref_array = np.array(scales)
ref_node = helper.make_node('Constant',
inputs=[],
outputs=['scales'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = TensorProto.FLOAT,
dims = ref_array.shape,
vals = ref_array.flatten().astype(float)))
value=onnx.helper.make_tensor(name='const_tensor',
data_type=TensorProto.FLOAT,
dims=ref_array.shape,
vals=ref_array.flatten().astype(float)))
graph = helper.make_graph([ref_node, y],
'upsample_bilinear_opset9_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
inputs=[helper.make_tensor_value_info(
"in", TensorProto.FLOAT, list(in_shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='upsample_bilinear_opset9_test')
model = helper.make_model(
graph, producer_name='upsample_bilinear_opset9_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
tvm_out = get_tvm_output(
model, in_array, target, ctx, out_shape, 'float32')
tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
def test_upsample():
_test_upsample_nearest()
_test_upsample_bilinear()
_test_upsample_bilinear_opset9()
def _test_softmax(inshape, axis):
opname = 'Softmax'
indata = np.random.uniform(size=inshape).astype(np.float32)
outshape = inshape
outdata = topi.testing.softmax_python(indata)
if isinstance(axis, int):
y = helper.make_node(opname, ['in'], ['out'], axis = axis)
y = helper.make_node(opname, ['in'], ['out'], axis=axis)
elif axis is None:
y = helper.make_node(opname, ['in'], ['out'])
graph = helper.make_graph([y],
opname+'_test',
inputs = [helper.make_tensor_value_info("in",
inputs=[helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name=opname+'_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 'float32')
tvm_out = get_tvm_output(
model, indata, target, ctx, outshape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_softmax():
_test_softmax((1, 10), None)
_test_softmax((1, 10), 1)
def verify_min(input_dim):
dtype = 'float32'
......@@ -574,25 +748,28 @@ def verify_min(input_dim):
graph = helper.make_graph([min_node],
"Min_test",
inputs = [helper.make_tensor_value_info("a_np1",
inputs=[helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np2",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np3",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='Min_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm_out = get_tvm_output(
model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_min():
verify_min((1, 3, 20, 20))
verify_min((20, 20))
def verify_max(input_dim):
dtype = 'float32'
......@@ -606,25 +783,28 @@ def verify_max(input_dim):
graph = helper.make_graph([max_node],
"Max_test",
inputs = [helper.make_tensor_value_info("a_np1",
inputs=[helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np2",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np3",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='Max_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm_out = get_tvm_output(
model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_max():
verify_max((1, 3, 20, 20))
verify_max((20, 20))
def verify_mean(input_dim):
dtype = 'float32'
......@@ -638,25 +818,28 @@ def verify_mean(input_dim):
graph = helper.make_graph([mean_node],
"Mean_test",
inputs = [helper.make_tensor_value_info("a_np1",
inputs=[helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np2",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np3",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='Mean_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm_out = get_tvm_output(
model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_mean():
verify_mean((1, 3, 20, 20))
verify_mean((20, 20))
def verify_hardsigmoid(input_dim, alpha, beta):
dtype = 'float32'
......@@ -664,13 +847,14 @@ def verify_hardsigmoid(input_dim, alpha, beta):
b_np = np.clip(a_np1 * alpha + beta, 0, 1)
hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta)
hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], [
"out"], alpha=alpha, beta=beta)
graph = helper.make_graph([hardsigmoid_node],
"HardSigmoid_test",
inputs = [helper.make_tensor_value_info("a_np1",
inputs=[helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='HardSigmoid_test')
......@@ -679,10 +863,12 @@ def verify_hardsigmoid(input_dim, alpha, beta):
tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_hardsigmoid():
verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6)
verify_hardsigmoid((20, 20), 0.3, 0.4)
def verify_argmin(input_dim, axis=None, keepdims=None):
def _argmin_numpy(data, axis=0, keepdims=True):
result = np.argmin(data, axis=axis)
......@@ -717,17 +903,19 @@ def verify_argmin(input_dim, axis=None, keepdims=None):
keepdims=keepdims)
graph = helper.make_graph([node],
"argmin_test",
inputs = [helper.make_tensor_value_info("a_np1",
inputs=[helper.make_tensor_value_info("a_np1",
TensorProto.INT32, list(a_np1.shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.INT32, list(b_np.shape))])
model = helper.make_model(graph, producer_name='argmin_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
tvm_out = get_tvm_output(
model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def verify_argmax(input_dim, axis=None, keepdims=None):
def _argmax_numpy(data, axis=0, keepdims=True):
result = np.argmax(data, axis=axis)
......@@ -763,66 +951,72 @@ def verify_argmax(input_dim, axis=None, keepdims=None):
graph = helper.make_graph([node],
"argmax_test",
inputs = [helper.make_tensor_value_info("a_np1",
inputs=[helper.make_tensor_value_info("a_np1",
TensorProto.INT32, list(a_np1.shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.INT32, list(b_np.shape))])
model = helper.make_model(graph, producer_name='argmax_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
tvm_out = get_tvm_output(
model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_arg_min_max():
'''Verify argmin and argmax'''
verify_argmin([3,4,4])
verify_argmax([3,4,4])
verify_argmin([3,4,4], axis=1)
verify_argmax([3,4,4], axis=0)
verify_argmin([3,4,4], keepdims=0)
verify_argmax([3,4,4], keepdims=1)
for axis in [None, 0,1,2]:
for keepdims in [None, True,False]:
verify_argmin([3,4,4], axis, keepdims)
verify_argmax([3,4,4], axis, keepdims)
def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
input_a = np.random.uniform(size=input_dim).astype(dtype)
out = np.empty(shape=out_dim, dtype=dtype)
verify_argmin([3, 4, 4])
verify_argmax([3, 4, 4])
verify_argmin([3, 4, 4], axis=1)
verify_argmax([3, 4, 4], axis=0)
verify_argmin([3, 4, 4], keepdims=0)
verify_argmax([3, 4, 4], keepdims=1)
for axis in [None, 0, 1, 2]:
for keepdims in [None, True, False]:
verify_argmin([3, 4, 4], axis, keepdims)
verify_argmax([3, 4, 4], axis, keepdims)
def verify_constantofshape(input_dim, value, dtype):
out = np.empty(shape=input_dim, dtype=dtype)
out.fill(value)
if is_shape == True:
fill_node = helper.make_node("ConstantFill", [], ["out"], shape=input_dim, value=value, **kwargs)
else:
fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)
fill_node = helper.make_node("ConstantOfShape", ["input"], ["output"],
value=helper.make_tensor(
'value',
mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
(1, ), (value, )))
if is_shape == True:
inputs = []
else:
inputs = [helper.make_tensor_value_info("input_a",
TensorProto.FLOAT, list(input_dim))]
inputs = [
helper.make_tensor_value_info("input", TensorProto.FLOAT, input_dim)
]
graph = helper.make_graph([fill_node],
graph = helper.make_graph(
[fill_node],
"fill_test",
inputs,
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out.shape))])
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT,
list(out.shape))
],
initializer=[
helper.make_tensor("input", TensorProto.INT32, (len(input_dim), ),
input_dim)
])
model = helper.make_model(graph, producer_name='fill_test')
for target, ctx in ctx_list():
if is_shape == True:
tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
else:
tvm_out = get_tvm_output(model, [input_a], target, ctx, out.shape)
tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
def test_constantfill():
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
def test_constantofshape():
verify_constantofshape((2, 3, 4, 5), 10, 'float32')
verify_constantofshape((3, 3), 0, 'int32')
verify_constantofshape((1, 2, 3), -1, 'float32')
def verify_pad(indata, pads, mode='constant', value=0.0):
......@@ -841,7 +1035,8 @@ def verify_pad(indata, pads, mode='constant', value=0.0):
pads=pads,
)
else:
outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
outdata = np.pad(indata, pad_width=np_pads,
mode='constant', constant_values=value)
node = helper.make_node(
'Pad',
inputs=['input'],
......@@ -852,22 +1047,30 @@ def verify_pad(indata, pads, mode='constant', value=0.0):
)
graph = helper.make_graph([node],
'pad_test',
inputs = [helper.make_tensor_value_info("input",
inputs=[helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output",
outputs=[helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='pad_test')
# tvm result
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm_out = get_tvm_output(
model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_pad():
verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 'constant', 0.0)
verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 'constant', 0.0)
verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 'constant', 5.0)
verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')
verify_pad(np.random.randn(2, 2).astype(
np.float32), [0, 1, 0, 0], 'constant', 0.0)
verify_pad(np.random.randn(2, 3).astype(
np.float32), [1, 0, 0, 1], 'constant', 0.0)
verify_pad(np.random.randn(3, 2).astype(
np.float32), [0, 0, 1, 0], 'constant', 5.0)
verify_pad(np.random.randn(1, 3, 4, 5).astype(
np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
verify_pad(np.random.randn(1, 3, 4, 5).astype(
np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')
def verify_reduce_x(name, indata, axis, keepdims):
indata = np.array(indata).astype(np.float32)
......@@ -893,16 +1096,18 @@ def verify_reduce_x(name, indata, axis, keepdims):
axes=axis, keepdims=keepdims)
graph = helper.make_graph([node],
'{}_test'.format(name),
inputs = [helper.make_tensor_value_info("input",
inputs=[helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output",
outputs=[helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='{}_test'.format(name))
# tvm result
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm_out = get_tvm_output(
model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_reduce_max():
verify_reduce_x("ReduceMax",
np.random.randn(3, 2, 2).astype(np.float32),
......@@ -914,6 +1119,7 @@ def test_reduce_max():
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_min():
verify_reduce_x("ReduceMin",
np.random.randn(3, 2, 2).astype(np.float32),
......@@ -925,6 +1131,7 @@ def test_reduce_min():
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_sum():
verify_reduce_x("ReduceSum",
np.random.randn(3, 2, 2).astype(np.float32),
......@@ -936,6 +1143,7 @@ def test_reduce_sum():
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_mean():
verify_reduce_x("ReduceMean",
np.random.randn(3, 2, 2).astype(np.float32),
......@@ -947,40 +1155,52 @@ def test_reduce_mean():
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def verify_split(indata, outdatas, split, axis=0):
indata = np.array(indata).astype(np.float32)
outdatas = [np.array(o).astype(np.float32) for o in outdatas]
if split:
split_index = range(len(split))
else:
split_index = range(len(outdatas))
node = helper.make_node(
'Split',
inputs=['input'],
outputs=['output_{}'.format(i) for i in range(len(split))],
outputs=['output_{}'.format(i) for i in range(len(split_index))],
axis=axis,
split=split
)
graph = helper.make_graph([node],
'split_test',
inputs = [helper.make_tensor_value_info("input",
inputs=[helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output_{}".format(i),
outputs=[helper.make_tensor_value_info("output_{}".format(i),
TensorProto.FLOAT, list(outdatas[i].shape))
for i in range(len(split))
for i in range(len(split_index))
])
model = helper.make_model(graph, producer_name='split_test')
for target, ctx in ctx_list():
output_shape = [o.shape for o in outdatas]
output_type = ['float32', 'float32', 'float32']
tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type)
tvm_out = get_tvm_output(
model, indata, target, ctx, output_shape, output_type)
for o, t in zip(outdatas, tvm_out):
tvm.testing.assert_allclose(o, t)
def test_split():
# 1D
verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
verify_split([1., 2., 3., 4., 5., 6.], [
[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
verify_split([1., 2., 3., 4., 5., 6.], [
[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
# 2D
verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
[[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
# Split evenly (unstack)
verify_split([1, 2, 3], [[1], [2], [3]], False)
def test_binary_ops():
in_shape = (1, 2, 3, 3)
......@@ -994,11 +1214,11 @@ def test_binary_ops():
z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
graph = helper.make_graph([z],
'_test',
inputs = [helper.make_tensor_value_info("in1",
inputs=[helper.make_tensor_value_info("in1",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in2",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
......@@ -1008,11 +1228,11 @@ def test_binary_ops():
x = np.random.uniform(size=in_shape).astype(dtype)
y = np.random.uniform(size=in_shape).astype(dtype)
z = np.random.uniform(size=(3,)).astype(dtype)
verify_binary_ops("Add",x, y, x + y, broadcast=None)
verify_binary_ops("Add", x, y, x + y, broadcast=None)
verify_binary_ops("Add", x, z, x + z, broadcast=True)
verify_binary_ops("Sub", x, y, x - y, broadcast=None)
verify_binary_ops("Sub", x, z, x - z, broadcast=True)
verify_binary_ops("Mul",x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, z, x * z, broadcast=True)
verify_binary_ops("Div", x, y, x / y, broadcast=None)
verify_binary_ops("Div", x, z, x / z, broadcast=True)
......@@ -1021,6 +1241,7 @@ def test_binary_ops():
verify_binary_ops("Less", x, y, x < y, broadcast=True)
verify_binary_ops("Equal", x, y, x == y, broadcast=True)
def test_single_ops():
in_shape = (1, 2, 3, 3)
dtype = "float32"
......@@ -1030,9 +1251,9 @@ def test_single_ops():
z = helper.make_node(op, ['in1'], ['out'])
graph = helper.make_graph([z],
'_test',
inputs = [helper.make_tensor_value_info("in1",
TensorProto.FLOAT, list(in_shape)),],
outputs = [helper.make_tensor_value_info("out",
inputs=[helper.make_tensor_value_info("in1",
TensorProto.FLOAT, list(in_shape)), ],
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
......@@ -1040,18 +1261,19 @@ def test_single_ops():
tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
x = np.random.uniform(size=in_shape).astype(dtype)
verify_single_ops("Neg",x, -x)
verify_single_ops("Abs",x, np.abs(x))
verify_single_ops("Reciprocal",x, 1/x)
verify_single_ops("Sqrt",x, np.sqrt(x))
verify_single_ops("Relu",x, np.maximum(x, 0))
verify_single_ops("Exp",x, np.exp(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Tanh",x, np.tanh(x))
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
verify_single_ops("Neg", x, -x)
verify_single_ops("Abs", x, np.abs(x))
verify_single_ops("Reciprocal", x, 1/x)
verify_single_ops("Sqrt", x, np.sqrt(x))
verify_single_ops("Relu", x, np.maximum(x, 0))
verify_single_ops("Exp", x, np.exp(x))
verify_single_ops("Log", x, np.log(x))
verify_single_ops("Log", x, np.log(x))
verify_single_ops("Tanh", x, np.tanh(x))
verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x)))
verify_single_ops("Softsign", x, x / (1 + np.abs(x)))
verify_single_ops("SoftPlus", x, np.log(1 + np.exp(x)))
def test_leaky_relu():
def leaky_relu_x(x, alpha):
@@ -1063,6 +1285,7 @@ def test_leaky_relu():
'LeakyRelu',
{'alpha': 0.25})
def test_elu():
def elu_x(x, alpha):
return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
@@ -1073,6 +1296,7 @@ def test_elu():
'Elu',
{'alpha': 0.25})
def test_selu():
def selu_x(x, alpha, gamma):
return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
@@ -1083,6 +1307,7 @@ def test_selu():
'Selu',
{'alpha': 0.25, 'gamma': 0.3})
def test_ThresholdedRelu():
def ThresholdedRelu_x(x, alpha):
out_np = np.clip(x, alpha, np.inf)
@@ -1095,6 +1320,7 @@ def test_ThresholdedRelu():
'ThresholdedRelu',
{'alpha': 0.25})
def test_ScaledTanh():
def ScaledTanh_x(x, alpha, beta):
return alpha * np.tanh(beta * x)
@@ -1105,6 +1331,7 @@ def test_ScaledTanh():
'ScaledTanh',
{'alpha': 0.25, 'beta': 0.3})
def test_ParametricSoftplus():
def ParametricSoftplus_x(x, alpha, beta):
return alpha * np.log(np.exp(beta * x) + 1)
@@ -1115,6 +1342,7 @@ def test_ParametricSoftplus():
'ParametricSoftplus',
{'alpha': 0.25, 'beta': 0.3})
def test_Scale():
def Scale_x(x, scale):
return scale * x
@@ -1125,6 +1353,7 @@ def test_Scale():
'Scale',
{'scale': 0.25})
def test_LogSoftmax():
_test_onnx_op_elementwise((1, 4),
topi.testing.log_softmax_python,
@@ -1138,7 +1367,8 @@ def check_torch_conversion(model, input_size):
dummy_input = torch.randn(*input_size)
file_name = '{}.onnx'.format(model.__name__)
# Set verbose=True for more output
torch.onnx.export(model(), dummy_input, file_name,
export_params=True, verbose=False)
onnx_model = onnx.load(file_name)
for target, ctx in ctx_list():
input_data = np.random.uniform(size=input_size).astype('int32')
@@ -1146,13 +1376,14 @@ def check_torch_conversion(model, input_size):
tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
tvm.testing.assert_allclose(c2_out, tvm_out)
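# check_torch_conversion round-trips a torchvision model: torch.onnx.export
# writes it to an .onnx file, onnx.load reads it back, and the caffe2
# reference output is compared against TVM's output for the same random input.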
def test_resnet():
check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224))
# check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
# def test_alexnet():
# Torch's ONNX export does not support the adaptive pooling used by AlexNet?
# check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
# Torch's ONNX export does not support the adaptive pooling used by vgg16?
# def test_vgg16():
@@ -1163,11 +1394,13 @@ def test_resnet():
# # Torch's ONNX export does not support the max pooling used by Squeezenet
# check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
def test_densenet():
check_torch_conversion(torchvision.models.densenet161, (1, 3, 224, 224))
def test_inception():
check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224))
# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_googlenet():
@@ -1177,6 +1410,7 @@ def test_inception():
# def test_shufflenetv2():
# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
def test_sign():
def Sign_x(x):
return np.sign(x)
@@ -1196,7 +1430,8 @@ def verify_not(indata, dtype):
graph = helper.make_graph([node],
'not_test',
inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))],
inputs=[helper.make_tensor_value_info(
"in", TensorProto.BOOL, list(x.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
model = helper.make_model(graph, producer_name='not_test')
@@ -1262,31 +1497,70 @@ def test_and():
verify_and(indata=[x, y], dtype=bool)
def verify_tile_v1(indata, outdata, **kwargs):
node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs)
graph = helper.make_graph([node],
'tile_test',
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
inputs=[helper.make_tensor_value_info(
"in", TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='tile_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, [indata], target, ctx, outdata.shape, opset=1)
tvm.testing.assert_allclose(outdata, tvm_out)
def verify_tile_v6(indata, repeats, outdata):
node = helper.make_node('Tile',
inputs=['input', 'repeats'],
outputs=['out'])
graph = helper.make_graph(
[node],
'tile_test',
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT,
list(indata.shape)),
helper.make_tensor_value_info("repeats", TensorProto.INT64,
list(repeats.shape))
],
outputs=[
helper.make_tensor_value_info("out", TensorProto.FLOAT,
list(outdata.shape))
],
initializer=[
helper.make_tensor("repeats", TensorProto.INT64,
list(repeats.shape), repeats)
])
model = helper.make_model(graph, producer_name='tile_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [indata],
target,
ctx,
outdata.shape,
opset=6)
tvm.testing.assert_allclose(outdata, tvm_out)
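# In opset 6, Tile takes 'repeats' as a second input tensor rather than an
# attribute, so the output shape depends on a tensor value. Supplying
# 'repeats' as an initializer keeps it constant, which should let the
# frontend evaluate it (e.g. via infer_value) and recover a static shape.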
def test_tile():
x = np.random.rand(2, 3, 4, 5).astype(np.float32)
repeats = np.random.randint(
low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
z = np.tile(x, repeats)
verify_tile_v1(x, z, repeats=repeats)
verify_tile_v6(x, repeats, z)
def verify_erf(indata, outdata):
node = helper.make_node('Erf', inputs=['in'], outputs=['out'])
graph = helper.make_graph([node],
'erf_test',
inputs=[helper.make_tensor_value_info(
'in', TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='erf_test')
@@ -1294,6 +1568,7 @@ def verify_erf(indata, outdata):
tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_erf():
x = np.random.rand(2, 3, 4, 6).astype(np.float32)
z = scipy.special.erf(x)
@@ -1337,7 +1612,9 @@ if __name__ == '__main__':
test_floor()
test_ceil()
test_clip()
test_onehot()
test_matmul()
test_batch_matmul()
test_gather()
test_lrn()
test_instance_norm()
@@ -1348,7 +1625,7 @@ if __name__ == '__main__':
test_forward_hardsigmoid()
test_forward_arg_min_max()
test_softmax()
test_constantofshape()
test_reduce_max()
test_reduce_min()
test_reduce_sum()