Unverified commit 967d7318 by Samuel, committed by GitHub

[MXNET] Broadcast and logical op support (#5461)

* [MXNET] Broadcast and logical op support

* Review comment fixed
parent 3f33b254
@@ -1712,6 +1712,33 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
        res = _op.nn.relu(res)
    return res

def _mx_broadcast_to(inputs, attrs):
    data = inputs[0]
    tgt_shape = attrs.get_int_tuple("shape", [])
    return _op.broadcast_to(data, tgt_shape)
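For intuition: `broadcast_to` expands every size-1 axis of the input to an explicit target shape, and the converter forwards that shape to `_op.broadcast_to` unchanged. A minimal numpy sketch of the same semantics (numpy stands in purely for illustration; the converter itself never calls it):

    import numpy as np

    x = np.ones((1, 2, 1), dtype="float32")
    y = np.broadcast_to(x, (4, 2, 3))  # size-1 axes are expanded to match
    print(y.shape)                     # (4, 2, 3)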

def _mx_logical_not(inputs, input_types):
    data = inputs[0]
    dtype = _infer_type(data).checked_type.dtype
    data = _op.cast(data, "bool") if dtype != "bool" else data
    return _op.cast(_op.logical_not(data), dtype)

def _mx_broadcast_logical(logical_op):
    def impl(inputs, input_types):
        lhs_type = _infer_type(inputs[0]).checked_type.dtype
        rhs_type = _infer_type(inputs[1]).checked_type.dtype
        lhs = _op.cast(inputs[0], "bool") if lhs_type != "bool" else inputs[0]
        rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1]
        return _op.cast(logical_op(lhs, rhs), lhs_type)
    return impl
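The cast round-trip in both helpers exists because Relay's `logical_*` operators take boolean operands, whereas MXNet's logical ops accept numeric tensors (nonzero counts as true) and return 0/1 values in the input dtype; casting the result back to `lhs_type` preserves that contract. A minimal sketch of the MXNet behavior being matched, assuming an MXNet 1.x install:

    import mxnet as mx

    a = mx.nd.array([0.0, 1.5, 0.0])
    b = mx.nd.array([0.0, 0.0, 2.0])
    out = mx.nd.broadcast_logical_or(a, b)
    print(out.dtype)      # float32, the input dtype, not bool
    print(out.asnumpy())  # [0. 1. 1.]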
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
@@ -1738,12 +1765,15 @@ _convert_map = {
"_copy" : _rename(_op.copy),
"relu" : _rename(_op.nn.relu),
"broadcast_add" : _rename(_op.add),
"broadcast_plus" : _rename(_op.add),
"broadcast_sub" : _rename(_op.subtract),
"broadcast_minus" : _rename(_op.subtract),
"broadcast_mul" : _rename(_op.multiply),
"broadcast_div" : _rename(_op.divide),
"broadcast_mod" : _rename(_op.mod),
"broadcast_maximum" : _rename(_op.maximum),
"broadcast_minimum" : _rename(_op.minimum),
"broadcast_power" : _rename(_op.power),
"arctan" : _rename(_op.atan),
"broadcast_equal" : _mx_compare(_op.equal, _rename),
"broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
@@ -1751,6 +1781,11 @@ _convert_map = {
"broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
"broadcast_lesser" : _mx_compare(_op.less, _rename),
"broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
"broadcast_logical_or" : _mx_broadcast_logical(_op.logical_or),
"broadcast_logical_and" : _mx_broadcast_logical(_op.logical_and),
"broadcast_logical_xor" : _mx_broadcast_logical(_op.logical_xor),
"broadcast_to" : _mx_broadcast_to,
"logical_not" : _mx_logical_not,
"_equal" : _mx_compare(_op.equal, _rename),
"_not_equal" : _mx_compare(_op.not_equal, _rename),
"_greater" : _mx_compare(_op.greater, _rename),
@@ -1860,6 +1895,7 @@ _convert_map = {
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
"broadcast_axes": _mx_broadcast_axis,
"BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
@@ -1897,7 +1933,6 @@ _convert_map = {
    # List of missing operators that are present in NNVMv1
    # TODO(tvm-tvm): support all operators.
    #
    # "contrib_fifo_buffer": _mx_contrib_fifo_buffer,
    "ring_buffer": _mx_contrib_fifo_buffer,
    # Qnn ops
@@ -301,11 +301,25 @@ def _mx_symbol(F, op_name, inputs):
    return op(*inputs)

def test_forward_broadcast_ops():
    for op in ["broadcast_add",
               "broadcast_plus",
               "broadcast_sub",
               "broadcast_minus",
               "broadcast_mul",
               "broadcast_div",
               "broadcast_mod",
               "broadcast_maximum",
               "broadcast_minimum",
               "broadcast_equal",
               "broadcast_not_equal",
               "broadcast_greater",
               "broadcast_greater_equal",
               "broadcast_lesser",
               "broadcast_lesser_equal",
               "broadcast_power",
               "broadcast_logical_or",
               "broadcast_logical_and",
               "broadcast_logical_xor"]:
        a_shape = (3, 4, 5)
        b_shape = (4, 5)
        if op == "broadcast_mod":
@@ -462,16 +476,51 @@ def test_forward_squeeze():

def test_forward_broadcast_axis():
    def verify(shape, axis, size):
        x_np = np.random.uniform(size=shape).astype("float32")
        for op in ["broadcast_axis",
                   "broadcast_axes"]:
            mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('x'), axis, size])
            ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np), axis, size])
            mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
            for target, ctx in ctx_list():
                for kind in ["graph", "debug"]:
                    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                    op_res = intrp.evaluate()(x_np)
                    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1, 2, 1), 2, 3)
    verify((1, 2, 1), (0, 2), (2, 3))
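`broadcast_axes` is MXNet's alias for `broadcast_axis`, which is why both names map to the same converter; unlike `broadcast_to` it takes per-axis `axis`/`size` pairs rather than a full target shape. A minimal sketch contrasting the two, assuming an MXNet 1.x install:

    import mxnet as mx

    x = mx.nd.ones((1, 2, 1))
    a = mx.nd.broadcast_axis(x, axis=(0, 2), size=(2, 3))  # grow axes 0 and 2
    b = mx.nd.broadcast_to(x, shape=(2, 2, 3))             # give the full shape
    print((a == b).asnumpy().all())                        # True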

def test_forward_broadcast_to():
    def verify(input_shape, shape):
        x_np = np.random.uniform(size=input_shape).astype("float32")
        ref_res = mx.nd.broadcast_to(mx.nd.array(x_np), shape=shape)
        mx_sym = mx.sym.broadcast_to(mx.sym.var("x"), shape=shape)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1, 2, 3), (3, 2, 3))
    verify((4, 1, 32, 32), (4, 8, 32, 32))

def test_forward_logical_not():
    a_shape = (3, 4, 5)
    dtype = 'float32'
    a_np = np.random.uniform(size=a_shape).astype(dtype)
    mx_sym = mx.sym.logical_not(mx.sym.var('a'))
    ref_res = mx.nd.logical_not(mx.nd.array(a_np))
    shapes = {'a': a_shape}
    mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
    for target, ctx in ctx_list():
        for kind in ["graph", "debug"]:
            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
            op_res = intrp.evaluate()(a_np)
            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

def test_forward_full():
    def verify(val, shape, dtype):
@@ -1061,6 +1110,8 @@ if __name__ == '__main__':
    test_forward_where()
    test_forward_arange()
    test_forward_broadcast_ops()
    test_forward_broadcast_to()
    test_forward_logical_not()
    test_forward_elemwise_ops()
    test_forward_scalar_ops()
    test_forward_slice_like()