Unverified Commit 967d7318 by Samuel Committed by GitHub

[MXNET]broadcast and logical op support (#5461)

* [MXNET]broadcast and logical op support

* Review comment fixed
parent 3f33b254
...@@ -1712,6 +1712,33 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params): ...@@ -1712,6 +1712,33 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
res = _op.nn.relu(res) res = _op.nn.relu(res)
return res return res
def _mx_broadcast_to(inputs, attrs):
data = inputs[0]
tgt_shape = attrs.get_int_tuple("shape", [])
return _op.broadcast_to(data, tgt_shape)
def _mx_logical_not(inputs, input_types):
data = inputs[0]
dtype = _infer_type(data).checked_type.dtype
data = _op.cast(data, "bool") if dtype != "bool" else data
return _op.cast(_op.logical_not(data), dtype)
def _mx_broadcast_logical(logical_op):
def impl(inputs, input_types):
lhs_type = _infer_type(inputs[0]).checked_type.dtype
rhs_type = _infer_type(inputs[1]).checked_type.dtype
lhs = _op.cast(inputs[0], "bool") if lhs_type != "bool" else inputs[0]
rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1]
return _op.cast(logical_op(lhs, rhs), lhs_type)
return impl
# Note: due to attribute conversion constraint # Note: due to attribute conversion constraint
# ops in the identity set must be attribute free # ops in the identity set must be attribute free
_identity_list = [ _identity_list = [
...@@ -1738,12 +1765,15 @@ _convert_map = { ...@@ -1738,12 +1765,15 @@ _convert_map = {
"_copy" : _rename(_op.copy), "_copy" : _rename(_op.copy),
"relu" : _rename(_op.nn.relu), "relu" : _rename(_op.nn.relu),
"broadcast_add" : _rename(_op.add), "broadcast_add" : _rename(_op.add),
"broadcast_plus" : _rename(_op.add),
"broadcast_sub" : _rename(_op.subtract), "broadcast_sub" : _rename(_op.subtract),
"broadcast_minus" : _rename(_op.subtract),
"broadcast_mul" : _rename(_op.multiply), "broadcast_mul" : _rename(_op.multiply),
"broadcast_div" : _rename(_op.divide), "broadcast_div" : _rename(_op.divide),
"broadcast_mod" : _rename(_op.mod), "broadcast_mod" : _rename(_op.mod),
"broadcast_maximum" : _rename(_op.maximum), "broadcast_maximum" : _rename(_op.maximum),
"broadcast_minimum" : _rename(_op.minimum), "broadcast_minimum" : _rename(_op.minimum),
"broadcast_power" : _rename(_op.power),
"arctan" : _rename(_op.atan), "arctan" : _rename(_op.atan),
"broadcast_equal" : _mx_compare(_op.equal, _rename), "broadcast_equal" : _mx_compare(_op.equal, _rename),
"broadcast_not_equal" : _mx_compare(_op.not_equal, _rename), "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
...@@ -1751,6 +1781,11 @@ _convert_map = { ...@@ -1751,6 +1781,11 @@ _convert_map = {
"broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename), "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
"broadcast_lesser" : _mx_compare(_op.less, _rename), "broadcast_lesser" : _mx_compare(_op.less, _rename),
"broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename), "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
"broadcast_logical_or" : _mx_broadcast_logical(_op.logical_or),
"broadcast_logical_and" : _mx_broadcast_logical(_op.logical_and),
"broadcast_logical_xor" : _mx_broadcast_logical(_op.logical_xor),
"broadcast_to" : _mx_broadcast_to,
"logical_not" : _mx_logical_not,
"_equal" : _mx_compare(_op.equal, _rename), "_equal" : _mx_compare(_op.equal, _rename),
"_not_equal" : _mx_compare(_op.not_equal, _rename), "_not_equal" : _mx_compare(_op.not_equal, _rename),
"_greater" : _mx_compare(_op.greater, _rename), "_greater" : _mx_compare(_op.greater, _rename),
...@@ -1860,6 +1895,7 @@ _convert_map = { ...@@ -1860,6 +1895,7 @@ _convert_map = {
"reverse" : _mx_reverse, "reverse" : _mx_reverse,
"squeeze" : _mx_squeeze, "squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis, "broadcast_axis": _mx_broadcast_axis,
"broadcast_axes": _mx_broadcast_axis,
"BlockGrad" : _mx_BlockGrad, "BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array, "shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding, "Embedding" : _mx_embedding,
...@@ -1897,7 +1933,6 @@ _convert_map = { ...@@ -1897,7 +1933,6 @@ _convert_map = {
# List of missing operators that are present in NNVMv1 # List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators. # TODO(tvm-tvm): support all operators.
# #
# "broadcast_to",
# "contrib_fifo_buffer": _mx_contrib_fifo_buffer, # "contrib_fifo_buffer": _mx_contrib_fifo_buffer,
"ring_buffer": _mx_contrib_fifo_buffer, "ring_buffer": _mx_contrib_fifo_buffer,
# Qnn ops # Qnn ops
......
...@@ -301,11 +301,25 @@ def _mx_symbol(F, op_name, inputs): ...@@ -301,11 +301,25 @@ def _mx_symbol(F, op_name, inputs):
return op(*inputs) return op(*inputs)
def test_forward_broadcast_ops(): def test_forward_broadcast_ops():
for op in ["broadcast_add", "broadcast_sub", "broadcast_mul", for op in ["broadcast_add",
"broadcast_div", "broadcast_mod", "broadcast_maximum", "broadcast_plus",
"broadcast_minimum", "broadcast_equal", "broadcast_not_equal", "broadcast_sub",
"broadcast_greater", "broadcast_greater_equal", "broadcast_minus",
"broadcast_lesser", "broadcast_lesser_equal"]: "broadcast_mul",
"broadcast_div",
"broadcast_mod",
"broadcast_maximum",
"broadcast_minimum",
"broadcast_equal",
"broadcast_not_equal",
"broadcast_greater",
"broadcast_greater_equal",
"broadcast_lesser",
"broadcast_lesser_equal",
"broadcast_power",
"broadcast_logical_or",
"broadcast_logical_and",
"broadcast_logical_xor"]:
a_shape = (3, 4, 5) a_shape = (3, 4, 5)
b_shape = (4, 5) b_shape = (4, 5)
if op == "broadcast_mod": if op == "broadcast_mod":
...@@ -462,17 +476,52 @@ def test_forward_squeeze(): ...@@ -462,17 +476,52 @@ def test_forward_squeeze():
def test_forward_broadcast_axis(): def test_forward_broadcast_axis():
def verify(shape, axis, size): def verify(shape, axis, size):
x_np = np.random.uniform(size=shape).astype("float32") x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size) for op in ["broadcast_axis",
mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size) "broadcast_axes"]:
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('x'),axis,size])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np),axis,size])
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list(): for target, ctx in ctx_list():
for kind in ["graph", "debug"]: for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_np) op_res = intrp.evaluate()(x_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((1, 2, 1), 2, 3) verify((1, 2, 1), 2, 3)
verify((1, 2, 1), (0, 2), (2, 3)) verify((1, 2, 1), (0, 2), (2, 3))
def test_forward_broadcast_to():
def verify(input_shape, shape):
x_np = np.random.uniform(size=input_shape).astype("float32")
ref_res = mx.nd.broadcast_to(mx.nd.array(x_np), shape=shape)
mx_sym = mx.sym.broadcast_to(mx.sym.var("x"), shape=shape)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((1, 2, 3), (3, 2, 3))
verify((4, 1, 32, 32), (4, 8, 32, 32))
def test_forward_logical_not():
a_shape = (3, 4, 5)
dtype = 'float32'
a_np = np.random.uniform(size=a_shape).astype(dtype)
mx_sym = mx.sym.logical_not(mx.sym.var('a'))
ref_res = mx.nd.logical_not(mx.nd.array(a_np))
shapes = {'a': a_shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
def test_forward_full(): def test_forward_full():
def verify(val, shape, dtype): def verify(val, shape, dtype):
ctx = mx.cpu() ctx = mx.cpu()
...@@ -1061,6 +1110,8 @@ if __name__ == '__main__': ...@@ -1061,6 +1110,8 @@ if __name__ == '__main__':
test_forward_where() test_forward_where()
test_forward_arange() test_forward_arange()
test_forward_broadcast_ops() test_forward_broadcast_ops()
test_forward_broadcast_to()
test_forward_logical_not()
test_forward_elemwise_ops() test_forward_elemwise_ops()
test_forward_scalar_ops() test_forward_scalar_ops()
test_forward_slice_like() test_forward_slice_like()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment