Commit 113b46ec by Siva Committed by Tianqi Chen

[NNVM][ONNX] Shape operator support (limited/differed) - #1297 (#1333)

parent 373a8caa
...@@ -258,10 +258,11 @@ class Reshape(OnnxOpConverter): ...@@ -258,10 +258,11 @@ class Reshape(OnnxOpConverter):
def _impl_v5(cls, inputs, attr, params): def _impl_v5(cls, inputs, attr, params):
if inputs[1].list_output_names()[0] in params: if inputs[1].list_output_names()[0] in params:
shape = tuple(params[inputs[1].list_output_names()[0]].asnumpy()) shape = tuple(params[inputs[1].list_output_names()[0]].asnumpy())
out = _sym.reshape(inputs[0], shape=shape)
else: else:
raise RuntimeError('Shape is not contained in graph initializer.') out = _sym.reshape_like(inputs[0], inputs[1])
return _sym.reshape(inputs[0], shape=shape)
return out
class Scale(OnnxOpConverter): class Scale(OnnxOpConverter):
...@@ -405,6 +406,36 @@ def _fully_connected(opset): ...@@ -405,6 +406,36 @@ def _fully_connected(opset):
return _impl return _impl
class Shape(OnnxOpConverter):
""" Operator converter for Shape.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# Result of this operator is prominently used by reshape operator.
# Just pass the input as it is so that reshape_like can be used there.
print("Shape: Differently implemented in NNVM as a bypass (dummy operator)")
return inputs[0]
class Cast(OnnxOpConverter):
""" Operator converter for Cast.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
@classmethod
def _impl_v5(cls, inputs, attr, params):
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
except ImportError as e:
raise ImportError(
"Unable to import onnx.mapping which is required {}".format(e))
return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -505,7 +536,7 @@ def _get_convert_map(opset): ...@@ -505,7 +536,7 @@ def _get_convert_map(opset):
# 'ArgMin' # 'ArgMin'
# defs/tensor # defs/tensor
'Cast': AttrCvt('cast', {'to': 'dtype'}), 'Cast': Cast.get_converter(opset),
'Reshape': Reshape.get_converter(opset), 'Reshape': Reshape.get_converter(opset),
'Concat': Renamer('concatenate'), 'Concat': Renamer('concatenate'),
'Split': AttrCvt('split', {'split': 'indices_or_sections'}), 'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
...@@ -514,6 +545,7 @@ def _get_convert_map(opset): ...@@ -514,6 +545,7 @@ def _get_convert_map(opset):
# 'Gather' # 'Gather'
# 'Squeeze' # 'Squeeze'
'Pad': Pad.get_converter(opset), 'Pad': Pad.get_converter(opset),
'Shape': Shape.get_converter(opset),
} }
...@@ -719,6 +751,9 @@ def from_onnx(model): ...@@ -719,6 +751,9 @@ def from_onnx(model):
""" """
g = GraphProto() g = GraphProto()
graph = model.graph graph = model.graph
try:
opset = model.opset_import[0].version if model.opset_import else 1 opset = model.opset_import[0].version if model.opset_import else 1
except AttributeError:
opset = 1
sym, params = g.from_onnx(graph, opset) sym, params = g.from_onnx(graph, opset)
return sym, params return sym, params
...@@ -5,20 +5,14 @@ from tvm.contrib import graph_runtime ...@@ -5,20 +5,14 @@ from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list from nnvm.testing.config import ctx_list
import onnx import onnx
from model_zoo import super_resolution, squeezenet1_1, lenet, resnet18_1_0 from model_zoo import super_resolution, squeezenet1_1, lenet, resnet18_1_0
from onnx import helper, TensorProto
def verify_onnx_forward_impl(graph_file, data_shape, out_shape): def get_tvm_output(model, x, target, ctx, out_shape, dtype='float32'):
import caffe2.python.onnx.backend
def get_caffe2_output(model, x, dtype='float32'):
prepared_backend = caffe2.python.onnx.backend.prepare(model)
W = {model.graph.input[0].name: x.astype(dtype)}
c2_out = prepared_backend.run(W)[0]
return c2_out
def get_tvm_output(model, x, target, ctx, dtype='float32'):
new_sym, params = nnvm.frontend.from_onnx(model) new_sym, params = nnvm.frontend.from_onnx(model)
input_name = model.graph.input[0].name input_name = model.graph.input[0].name
shape_dict = {input_name: x.shape} shape_dict = {input_name: x.shape}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params) dtype_dict = {input_name: dtype}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
# set inputs # set inputs
m.set_input(input_name, tvm.nd.array(x.astype(dtype))) m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
...@@ -28,12 +22,21 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): ...@@ -28,12 +22,21 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
out = m.get_output(0, tvm.nd.empty(out_shape, dtype)) out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy() return out.asnumpy()
def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
import caffe2.python.onnx.backend
def get_caffe2_output(model, x, dtype='float32'):
prepared_backend = caffe2.python.onnx.backend.prepare(model)
W = {model.graph.input[0].name: x.astype(dtype)}
c2_out = prepared_backend.run(W)[0]
return c2_out
dtype = 'float32' dtype = 'float32'
x = np.random.uniform(size=data_shape) x = np.random.uniform(size=data_shape)
model = onnx.load(graph_file) model = onnx.load(graph_file)
c2_out = get_caffe2_output(model, x, dtype) c2_out = get_caffe2_output(model, x, dtype)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, dtype) tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_super_resolution_example(): def verify_super_resolution_example():
...@@ -48,8 +51,66 @@ def verify_lenet(): ...@@ -48,8 +51,66 @@ def verify_lenet():
def verify_resnet18(): def verify_resnet18():
verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000)) verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
def test_reshape():
in_shape = (4, 3, 3, 4)
ref_shape = (3, 4, 4, 3)
ref_array = np.array(ref_shape)
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = onnx.TensorProto.INT32,
dims = ref_array.shape,
vals = ref_array.flatten().astype(int)))
reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
graph = helper.make_graph([ref_node, reshape_node],
"reshape_test",
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='reshape_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape)
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
np.testing.assert_allclose(ref_shape, tvm_out.shape)
def test_reshape_like():
in_shape = (4, 3, 3, 4)
ref_shape = (3, 4, 4, 3)
ref_array = np.random.uniform(size=ref_shape).astype('float32')
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = onnx.TensorProto.FLOAT,
dims = ref_array.shape,
vals = ref_array.flatten().astype(float)))
copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
graph = helper.make_graph([ref_node, copy_node, reshape_node],
"reshape_like_test",
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='reshape_like_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape)
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
np.testing.assert_allclose(ref_shape, tvm_out.shape)
if __name__ == '__main__': if __name__ == '__main__':
# verify_super_resolution_example() # verify_super_resolution_example()
# verify_squeezenet1_1() # verify_squeezenet1_1()
# verify_lenet() # verify_lenet()
verify_resnet18() verify_resnet18()
test_reshape()
test_reshape_like()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment