Commit 7bc990ad by ziheng Committed by Tianqi Chen

[RELAY] Support concatenate. (#2298)

parent f55d9628
......@@ -364,6 +364,11 @@ def _squeeze(children, attrs, odtype='float32'):
return op.squeeze(children[0], axis)
def _concatenate(children, attrs, odtype='float32'):
axis = attrs.get_int('axis', None)
return op.concatenate(children, axis)
NNVM_OP_2_RELAY_OP = {
'flatten': _nn_batch_flatten,
'dense': _dense,
......@@ -422,6 +427,7 @@ NNVM_OP_2_RELAY_OP = {
'strided_slice': _strided_slice,
'split': _split,
'squeeze': _squeeze,
'concatenate': _concatenate,
}
......@@ -436,7 +442,7 @@ def to_relay(graph, shape_dict, dtype_dict, params):
shape_dict : dict of str to shape
The input shape.
dtype_dict : dict of str to shape
dtype_dict : dict of str to str/dtype
The input shape.
params : dict of str to array
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment