Commit d39a4ea0 by Haichen Shen Committed by Tianqi Chen

Add MXNet converter for RNN layer ops (#3125)

parent 2ed7f95a
......@@ -26,6 +26,7 @@ from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import expr as _expr
from . import ty as _ty
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen
......@@ -427,6 +428,8 @@ class GraphExecutor(_interpreter.Executor):
self.target = target
def _make_executor(self, func):
ret_type = ir_pass.infer_type(func).ret_type
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
graph_json, mod, params = build(func, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
......@@ -440,7 +443,12 @@ class GraphExecutor(_interpreter.Executor):
# Run the module, and fetch the output.
gmodule.run()
# make a copy so multiple invocation won't hurt perf.
if num_outputs == 1:
return gmodule.get_output(0).copyto(_nd.cpu(0))
outputs = []
for i in range(num_outputs):
outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
return outputs
return _graph_wrapper
......
......@@ -34,6 +34,12 @@ from .nnvm_common import _warn_not_used
__all__ = ['from_mxnet']
_activation_map = {
"sigmoid": _op.sigmoid,
"tanh" : _op.tanh,
"relu" : _op.nn.relu
}
def _mx_fully_connected(inputs, attrs):
import mxnet as mx
units = attrs.get_int("num_hidden")
......@@ -66,12 +72,6 @@ def _get_channel_axis(layout, op_name):
def _mx_activations(inputs, attrs):
act_type = attrs.get_str("act_type")
assert len(inputs) == 1
if act_type == "sigmoid":
return _op.sigmoid(inputs[0])
if act_type == "tanh":
return _op.tanh(inputs[0])
if act_type == "relu":
return _op.nn.relu(inputs[0])
if act_type == "softrelu":
def _stable_softrelu(x):
# log(1 + exp(-abs(x))) + relu(x)
......@@ -80,8 +80,10 @@ def _mx_activations(inputs, attrs):
return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
_op.nn.relu(x))
return _stable_softrelu(inputs[0])
if act_type not in _activation_map:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend MXNet.'.format(act_type))
return _activation_map[act_type](inputs[0])
def _mx_compare(new_op, wrapper):
......@@ -189,7 +191,8 @@ def _mx_pooling(inputs, attrs):
def _mx_adaptive_avg_pooling(inputs, attrs):
output_size = attrs.get_int_tuple("output_size", [])
if output_size != (1,):
raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
raise tvm.error.OpAttributeUnimplemented(
"AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
return _op.nn.global_avg_pool2d(inputs[0])
......@@ -471,7 +474,7 @@ def _mx_take(inputs, attrs):
assert len(inputs) == 2
mode = attrs.get_str("mode", "clip")
if mode == "raise":
raise RuntimeError("take doesn't support raise mode")
raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
axis = attrs.get_int("axis", 0)
return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
......@@ -571,13 +574,13 @@ def _mx_l2_normalize(inputs, attrs):
def _mx_shape_array(inputs, attrs):
assert len(inputs) == 1
if attrs.get_int("lhs_begin", None) is not None:
raise RuntimeError("shape_array doesn't support lhs_begin")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_begin")
if attrs.get_int("lhs_end", None) is not None:
raise RuntimeError("shape_array doesn't support lhs_end")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_end")
if attrs.get_int("rhs_begin", None) is not None:
raise RuntimeError("shape_array doesn't support rhs_begin")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_begin")
if attrs.get_int("rhs_end", None) is not None:
raise RuntimeError("shape_array doesn't support rhs_end")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_end")
return _op.shape_of(inputs[0], dtype='int64')
......@@ -657,6 +660,101 @@ def _mx_argsort(inputs, attrs):
return _op.argsort(inputs[0], **new_attrs)
def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]
def _mx_rnn_layer(inputs, attrs):
def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activation):
i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
out = _activation_map[activation](i2h + h2h)
return out, [out]
def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
dtype = ir_pass.infer_type(data).checked_type.dtype
i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1)
h2h_r, h2h_z, h2h = _op.split(h2h, indices_or_sections=3, axis=1)
reset_gate = _activation_map["sigmoid"](i2h_r + h2h_r)
update_gate = _activation_map["sigmoid"](i2h_z + h2h_z)
next_h_tmp = _activation_map["tanh"](reset_gate * h2h + i2h)
next_h = (_expr.const(1, dtype) - update_gate) * next_h_tmp + update_gate * states[0]
return next_h, [next_h]
def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
gates = i2h + h2h
slice_gates = _op.split(gates, indices_or_sections=4, axis=1)
in_gate = _activation_map["sigmoid"](slice_gates[0])
forget_gate = _activation_map["sigmoid"](slice_gates[1])
in_transform = _activation_map["tanh"](slice_gates[2])
out_gate = _activation_map["sigmoid"](slice_gates[3])
next_c = forget_gate * states[1] + in_gate * in_transform
next_h = out_gate * _activation_map["tanh"](next_c)
return next_h, [next_h, next_c]
num_layers = attrs.get_int("num_layers", 1)
mode = attrs.get_str("mode")
if mode.startswith("rnn"):
mode, activation = mode.split('_')
assert mode in ["rnn", "gru", "lstm"]
bidirectional = attrs.get_bool("bidirectional", False)
if bidirectional:
raise tvm.error.OpAttributeUnimplemented(
"Bidirectional RNN op is not supported yet")
layout = attrs.get_str("layout", "TNC")
if layout != "TNC":
raise tvm.error.OpAttributeUnimplemented(
"RNN with layout other than TNC is not supported yet")
num_states = 2 if mode == 'lstm' else 1
assert len(inputs) == num_states + 2
seq_data = inputs[0]
concat_weight = inputs[1]
concat_states = inputs[2:]
seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
assert len(concat_weight) == num_layers * 4
weights = []
bias = []
states = []
for i in range(num_layers):
w = []
b = []
s = []
for j in range(2):
w.append(concat_weight[i*2 + j].args[0])
b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
for state in concat_states:
s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
weights.append(w)
bias.append(b)
states.append(s)
seq_output = []
for t in range(seq_len):
data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
for l in range(num_layers):
if mode == "rnn":
out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
elif mode == "gru":
out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
else: # mode == "lstm"
out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
states[l] = new_states
data = out
seq_output.append(out)
outputs = [_op.stack(seq_output, axis=0)]
for i in range(num_states):
outputs.append(_op.stack([s[i] for s in states], axis=0))
return outputs
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
......@@ -807,6 +905,9 @@ _convert_map = {
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
"_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
......
......@@ -527,6 +527,54 @@ def test_forward_bilinear_resize():
mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
def test_forward_rnn_layer():
def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
if mode == "rnn":
layer = gluon.rnn.RNN(hidden_size, num_layers)
elif mode == "gru":
layer = gluon.rnn.GRU(hidden_size, num_layers)
else: # mode == "lstm"
layer = gluon.rnn.LSTM(hidden_size, num_layers)
num_states = 2 if mode == "lstm" else 1
layer.initialize()
dtype = "float32"
data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
states_np = []
states_mx = []
shape_dict = {'data0': data_np.shape}
inputs = {'data0': data_np}
for i in range(num_states):
s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
states_np.append(s)
states_mx.append(mx.nd.array(s))
shape_dict['data%s' % (i+1)] = s.shape
inputs['data%s' % (i+1)] = s
layer.hybridize()
mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
mx_res = [mx_out] + mx_states
mx_sym = layer._cached_graph[1]
mx_params = {}
for name, param in layer.collect_params().items():
mx_params[name] = param._reduce()
new_sym, params = relay.frontend.from_mxnet(
mx_sym, shape=shape_dict, arg_params=mx_params)
for target, ctx in ctx_list():
# only test graph runtime because debug runtime is too slow
for kind in ["graph"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(**inputs, **params)
assert len(op_res) == len(mx_res)
for i, val in enumerate(op_res):
tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
for mode in ["rnn", "gru", "lstm"]:
verify(mode, 64, 10, 64, 1)
verify(mode, 64, 10, 64, 2)
verify(mode, 64, 10, 32, 2)
if __name__ == '__main__':
test_forward_mlp()
......@@ -566,3 +614,4 @@ if __name__ == '__main__':
test_forward_take()
test_forward_gather_nd()
test_forward_bilinear_resize()
test_forward_rnn_layer()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment