Commit f9e8c116 authored by Haichen Shen, committed by Jared Roesch

[Frontend][MXNet] Fix mxnet converter for hybridblock and add div_sqrt_dim (#3701)

* Fix mxnet converter for hybrid block

* tweak

* fix rebase

* fix

* add test
parent 5416d1e4
@@ -715,7 +715,7 @@ def _mx_topk(inputs, attrs):
     return _op.topk(inputs[0], **new_attrs)
 
 
-def _mx_SequenceMask(inputs, attrs):
+def _mx_sequence_mask(inputs, attrs):
     assert len(inputs) == 1 or len(inputs) == 2
     new_attrs = {}
     use_sequence_length = attrs.get_bool('use_sequence_length', False)
@@ -727,6 +727,15 @@ def _mx_SequenceMask(inputs, attrs):
     return inputs[0]
 
 
+def _mx_contrib_div_sqrt_dim(inputs, _):
+    assert len(inputs) == 1
+    ndim = len(_infer_type(inputs[0]).checked_type.shape)
+    dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
+    sqrt_dim = _op.sqrt(dim.astype('float32'))
+    out = inputs[0] / sqrt_dim
+    return out
+
+
 def _mx_rnn_param_concat(inputs, _):
     # We don't need to concatenate RNN params because we will unravel the RNN op
     return [inputs]
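For reference, the new _mx_contrib_div_sqrt_dim converter mirrors MXNet's contrib.div_sqrt_dim operator, which rescales its input by the square root of the last-axis size, exactly what the converter above computes via shape_of/take/sqrt. A minimal sketch of that semantics checked against MXNet itself (the shape used here is arbitrary):

    import numpy as np
    import mxnet as mx

    x = np.random.uniform(size=(3, 4)).astype("float32")
    # div_sqrt_dim divides by sqrt of the last dimension: out = x / sqrt(x.shape[-1])
    expected = x / np.sqrt(np.float32(x.shape[-1]))
    ref = mx.nd.contrib.div_sqrt_dim(mx.nd.array(x)).asnumpy()
    np.testing.assert_allclose(ref, expected, rtol=1e-5)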
@@ -1014,11 +1023,12 @@ _convert_map = {
     "Embedding" : _mx_embedding,
     "argsort" : _mx_argsort,
     "topk" : _mx_topk,
-    "SequenceMask" : _mx_SequenceMask,
+    "SequenceMask" : _mx_sequence_mask,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     "LinearRegressionOutput" : _mx_linear_regression_output,
     "smooth_l1" : _mx_smooth_l1,
+    "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
     # vision
     "_contrib_BilinearResize2D" : _mx_resize,
     "_contrib_MultiBoxPrior" : _mx_multibox_prior,
@@ -1183,8 +1193,10 @@ def from_mxnet(symbol,
         params = {}
         for k, v in symbol.collect_params().items():
             params[k] = _nd.array(v.data().asnumpy())
-        data = mx.sym.Variable("data")
-        sym = symbol(data)
+        inputs = []
+        for name in shape:
+            inputs.append(mx.sym.Variable(name))
+        sym = symbol(*inputs)
         if isinstance(sym, (list, tuple)):
             sym = mx.sym.Group(sym)
         shape, dtype = _update_shape_dtype(shape, dtype, params)
......
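The from_mxnet change above is what makes HybridBlock conversion work for graphs with more than one input: instead of tracing with a single hard-coded "data" variable, one mx.sym.Variable is now created per entry in the shape dictionary. A hedged usage sketch, where TwoInputBlock and the input names "a" and "b" are made up for illustration (inputs are bound in the dictionary's iteration order):

    import mxnet as mx
    from mxnet.gluon import nn
    from tvm import relay

    class TwoInputBlock(nn.HybridBlock):
        # toy block with two inputs; tracing previously assumed a single "data" input
        def hybrid_forward(self, F, a, b):
            return F.elemwise_add(a, b)

    block = TwoInputBlock()
    block.initialize()
    shape_dict = {"a": (1, 16), "b": (1, 16)}
    mod, params = relay.frontend.from_mxnet(block, shape_dict)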
@@ -714,6 +714,19 @@ def test_forward_sequence_mask():
     verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
     verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')
 
+def test_forward_contrib_div_sqrt_dim():
+    def verify(shape):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.contrib.div_sqrt_dim(mx.nd.array(x_np))
+        mx_sym = mx.sym.contrib.div_sqrt_dim(mx.sym.var("x"))
+        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((3, 4))
+    verify((3, 4, 5))
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -759,3 +772,4 @@ if __name__ == '__main__':
     test_forward_argsort()
     test_forward_topk()
     test_forward_sequence_mask()
+    test_forward_contrib_div_sqrt_dim()
@@ -38,6 +38,7 @@ def schedule_batch_matmul(outs):
     s: Schedule
         The computation schedule for the op.
     """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
     def _schedule(op):
@@ -49,6 +50,9 @@ def schedule_batch_matmul(outs):
         BB = s.cache_read(B, "shared", [C])
         BL = s.cache_read(BB, "local", [C])
         CC = s.cache_write(C, "local")
+        if op not in s.outputs:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
         b, y, x = s[C].op.axis
         y_bn = get_max_power2_factor(M, 64)
......
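The batch_matmul schedule tweak handles the case where the matmul is fused with trailing elementwise operators: when the matmul stage is no longer the schedule's output, it is inlined and the thread/block structure is applied to the fused output instead. A minimal sketch of such a fused workload, assuming the TVM/TOPI Python API of this era (the shapes are arbitrary, and actually compiling the kernel would additionally require a CUDA toolchain):

    import tvm
    import topi

    # batch_matmul expects (batch, M, K) x (batch, N, K)
    A = tvm.placeholder((1, 64, 32), name="A")
    B = tvm.placeholder((1, 16, 32), name="B")
    C = topi.nn.batch_matmul(A, B)
    D = topi.nn.relu(C)  # trailing elementwise op fused with the matmul

    with tvm.target.create("cuda"):
        # with the fix, the schedule inlines C and schedules the fused output D
        s = topi.generic.schedule_batch_matmul([D])
        # tvm.build(s, [A, B, D], "cuda") would then compile the fused kernel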