Commit 0482623e by Jon Soifer Committed by Zhi

[Relay][TensorFlow] Add support for SquaredDifference (#3930)

* Add support for SquaredDifference and StopGradient; minor fix in BatchMatMul

* Remove stopgradient change

* Resolve PR comment

* Dummy change to retrigger CI

* dummy change to retrigger CI
parent da039794
......@@ -469,7 +469,9 @@ def _batch_matmul():
# reshape result back to n-dimensional
if len(orig_shape_x) > 3:
final_shape = attr['_output_shapes'][0]
final_shape = list(orig_shape_x)
final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1]
ret = _op.reshape(ret, newshape=final_shape)
return ret
......@@ -1227,6 +1229,12 @@ def _one_hot():
extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
return _impl
def _squared_difference():
def _impl(inputs, attr, params):
difference = _op.subtract(inputs[0], inputs[1])
return _op.multiply(difference, difference)
return _impl
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -1334,6 +1342,7 @@ _convert_map = {
'SplitV' : _split(True),
'Sqrt' : AttrCvt('sqrt'),
'Square' : _square(),
'SquaredDifference' : _squared_difference(),
'Squeeze' : _squeeze(),
'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'),
......
......@@ -1852,6 +1852,16 @@ def test_forward_erf():
tf.math.erf(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')
def test_forward_squared_difference():
ishape = (1, 3, 10, 14)
inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1")
in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2")
out = tf.math.squared_difference(in1, in2)
compare_tf_with_tvm([inp_array_a, inp_array_b], [in1.name, in2.name], out.name)
def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
......@@ -2253,6 +2263,7 @@ if __name__ == '__main__':
test_forward_bias_add()
test_forward_zeros_like()
test_forward_erf()
test_forward_squared_difference()
# Reductions
test_forward_argminmax()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment