Commit 846d9ce0 by Siva Committed by Yizhi Liu

[ONNX][FRONTEND] Constantfill - #1539 (#1764)

parent b14bb7f9
......@@ -611,6 +611,46 @@ class Softmax(OnnxOpConverter):
'axis': ('axis', 1),
})(inputs, attr, params)
class ConstantFill(OnnxOpConverter):
    """ Operator converter for ConstantFill.

    Produces a tensor filled with ``value``. The output shape comes from
    one of three places:
      * the ``shape`` attribute (no inputs allowed in that case),
      * the first input's *values* when ``input_as_shape`` is set
        (the input must be a constant found in ``params``),
      * otherwise the first input's own shape (``full_like`` semantics).
    ``extra_shape`` appends trailing dimensions and is only legal when an
    explicit shape is available (not with ``full_like``).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        is_full = True
        num_inputs = len(inputs)
        if 'shape' in attr:
            # 'shape' attribute and an input tensor are mutually exclusive.
            if num_inputs > 0:
                raise ImportError(
                    "Can't set shape and input tensor at a time")
            shape = attr.pop('shape')
        else:
            if num_inputs == 0:
                raise ImportError(
                    "Either shape attribute or input should be set")
            if 'input_as_shape' in attr and attr['input_as_shape']:
                # The input tensor holds the target shape as values; it must
                # be a graph constant so we can read it out of params here.
                shape = params[inputs[0].list_output_names()[0]].asnumpy()
            else:
                is_full = False
        if not is_full:
            if 'extra_shape' in attr:
                raise ImportError(
                    "Extra Shape not supported with fill_like")
            out = AttrCvt(
                op_name='full_like',
                transforms={'value': 'fill_value'},
                ignores=['dtype'])(inputs, attr)
            # dtype arrives as a bytes attribute; decode before casting.
            return _sym.cast(out, dtype=attr['dtype'].decode("utf-8"))
        else:
            if 'extra_shape' in attr:
                # Concatenate (not add) the extra dims. Cast both sides to
                # list: when `shape` is a numpy array (input_as_shape path),
                # a bare `+` would do elementwise addition instead of
                # appending dimensions.
                shape = list(shape) + list(attr.pop('extra_shape'))
            return AttrCvt(
                op_name='full',
                transforms={'value': 'fill_value'},
                extras={'shape': shape})(inputs, attr)
# Compatible operators that do NOT require any conversion
# (currently empty: every supported op goes through a converter).
_identity_list = []
......@@ -628,7 +668,7 @@ def _get_convert_map(opset):
'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
'ScaledTanh': ScaledTanh.get_converter(opset),
'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
# 'ConstantFill'
'ConstantFill': ConstantFill.get_converter(opset),
# 'GivenTensorFill'
'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
'Scale': Scale.get_converter(opset),
......
......@@ -680,6 +680,38 @@ def test_forward_arg_min_max():
verify_argmin([3,4,4], axis, keepdims)
verify_argmax([3,4,4], axis, keepdims)
def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
    """Build a one-node ConstantFill model and check TVM against numpy.

    If `is_shape` is true, the node carries the target shape in its
    `shape` attribute and takes no inputs; otherwise it fills a tensor
    shaped like the `input_a` input (fill_like mode). Extra attributes
    such as `extra_shape` can be forwarded through `**kwargs`.
    """
    input_a = np.random.uniform(size=input_dim).astype(dtype)
    # Reference result: a tensor of `out_dim` filled with `value`.
    out = np.empty(shape=out_dim, dtype=dtype)
    out.fill(value)
    if is_shape:  # idiomatic truthiness instead of `== True`
        fill_node = helper.make_node("ConstantFill", [], ["out"],
                                     shape=input_dim, value=value, **kwargs)
    else:
        fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"],
                                     value=value, dtype=dtype, **kwargs)
    graph = helper.make_graph([fill_node],
                              "fill_test",
                              inputs=[helper.make_tensor_value_info(
                                  "input_a", TensorProto.FLOAT,
                                  list(input_dim))],
                              outputs=[helper.make_tensor_value_info(
                                  "out", TensorProto.FLOAT,
                                  list(out.shape))])
    model = helper.make_model(graph, producer_name='fill_test')
    for target, ctx in ctx_list():
        if is_shape:
            # Attribute-driven shape: the model consumes no input tensors.
            tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
        else:
            tvm_out = get_tvm_output(model, [input_a], target, ctx, out.shape)
        np.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
def test_constantfill():
    """Exercise ConstantFill in shape-attribute and fill_like modes."""
    base = (2, 3, 4, 5)
    # shape given via attribute
    verify_constantfill(True, base, base, 10, 'float32')
    # shape taken from the input tensor (fill_like)
    verify_constantfill(False, base, base, 10, 'float32')
    # attribute shape with trailing extra_shape dims appended
    verify_constantfill(True, base, base + (4, 5, 6), 10, 'float32',
                        extra_shape=(4, 5, 6))
if __name__ == '__main__':
# verify_super_resolution_example()
# verify_squeezenet1_1()
......@@ -704,3 +736,4 @@ if __name__ == '__main__':
test_forward_hardsigmoid()
test_forward_arg_min_max()
test_softmax()
test_constantfill()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment