Commit 9b0e4990 by Sergey Mironov, committed by Tianqi Chen

[NNVM] TF: Add Pack operation (#1570)

parent 6cd5a8f9
@@ -16,7 +16,7 @@ namespace top {
 struct ConcatenateParam : public dmlc::Parameter<ConcatenateParam> {
   int axis;
   DMLC_DECLARE_PARAMETER(ConcatenateParam) {
-    DMLC_DECLARE_FIELD(axis).set_lower_bound(0).set_default(1)
+    DMLC_DECLARE_FIELD(axis).set_default(1)
     .describe("the axis to be concated.");
   }
 };
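Note that dropping `set_lower_bound(0)` is what permits a negative `axis`, which TensorFlow's `Pack`/`tf.stack` semantics require: a negative axis counts back from the end, as in NumPy. A minimal Python sketch of the normalization rule (hypothetical code, not part of NNVM; the real C++ change is in `ConcatenateInferShape` below):

```python
# Sketch of negative-axis normalization (illustrative, not NNVM code):
# a negative axis counts back from the number of dimensions.
def normalize_axis(axis, ndim):
    return axis if axis >= 0 else axis + ndim

assert normalize_axis(-1, 3) == 2  # last axis of a 3-d tensor
assert normalize_axis(1, 3) == 1   # non-negative axes pass through unchanged
```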
@@ -339,6 +339,14 @@ def _concat():
                       extras={'axis': axis.asnumpy()[0]})(inputs, attr)
     return _impl

+def _pack():
+    def _impl(inputs, attr, params):
+        axis = int(attr["axis"])
+        inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
+        return _sym.concatenate(*inputs_reshaped, axis=axis)
+    return _impl
+
 def _reshape():
     def _impl(inputs, attr, params):
         try:
@@ -673,6 +681,7 @@ _convert_map = {
     'Minimum'         : _elemwise('min'),
     'Sum'             : _sum(),
     'Square'          : _square(),
+    'Pack'            : _pack(),
     'Relu'            : AttrCvt('relu'),
     'Reshape'         : _reshape(),
     'ResizeBilinear'  : _resize_bilinear(),
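TensorFlow's `Pack` op (what `tf.stack` emits into the graph) has no one-to-one NNVM operator, so `_pack` decomposes it: `expand_dims` gives every input a new length-1 axis at `axis`, and `concatenate` joins the inputs along that new axis. A sketch of the equivalence, using NumPy as a stand-in for the `_sym` ops:

```python
import numpy as np

# NumPy stand-in for the _pack decomposition above (a sketch, not frontend code).
def pack_via_concat(inputs, axis=0):
    # Give each input a new length-1 axis, then join along it: exactly stacking.
    expanded = [np.expand_dims(x, axis=axis) for x in inputs]
    return np.concatenate(expanded, axis=axis)

a, b = np.zeros((3, 2)), np.ones((3, 2))
for axis in (-3, -1, 0, 2):  # stacking rank-2 inputs yields rank 3, so -3..2 are legal
    assert np.array_equal(pack_via_concat([a, b], axis=axis),
                          np.stack([a, b], axis=axis))
```

The negative-axis cases are exactly why `concatenate` had to stop rejecting `axis < 0`, hence the parameter and shape-inference changes in this commit.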
@@ -93,23 +93,24 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
   TShape dshape;
   dim_t size = 0;
   bool has_zero = false;
+  int axis = param.axis >= 0 ? param.axis : in_shape->at(0).ndim() + param.axis;
   for (size_t i = 0; i < in_shape->size(); ++i) {
     TShape tmp = (*in_shape)[i];
     if (tmp.ndim()) {
-      CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim())
-          << "concat dim " << param.axis << " out of range of input shape " << tmp;
-      has_zero = tmp[param.axis] == 0 || has_zero;
-      size += tmp[param.axis];
-      tmp[param.axis] = 0;
+      CHECK_LT(static_cast<dim_t>(axis), tmp.ndim())
+          << "concat dim " << axis << " out of range of input shape " << tmp;
+      has_zero = tmp[axis] == 0 || has_zero;
+      size += tmp[axis];
+      tmp[axis] = 0;
       shape_assign(&dshape, tmp);
     }
   }

   TShape tmp = (*out_shape)[0];
   if (tmp.ndim()) {
-    CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim())
-        << "concat dim " << param.axis << " out of range of input shape " << tmp;
-    tmp[param.axis] = 0;
+    CHECK_LT(static_cast<dim_t>(axis), tmp.ndim())
+        << "concat dim " << axis << " out of range of input shape " << tmp;
+    tmp[axis] = 0;
     shape_assign(&dshape, tmp);
   }

@@ -119,7 +120,7 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
     NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, i, dshape);
   }

-  if (!has_zero) dshape[param.axis] = size;
+  if (!has_zero) dshape[axis] = size;
   NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
   return dshape.Size() != 0;
 }
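For reference, the inference logic reads roughly as follows in Python, under the assumption that every input has the same rank (a sketch, not the C++ API):

```python
# Rough Python model of ConcatenateInferShape (illustrative, not NNVM code):
# sizes add up along the concat axis; all other axes must agree.
def concat_out_shape(shapes, axis):
    ndim = len(shapes[0])
    axis = axis if axis >= 0 else axis + ndim  # the normalization this commit adds
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)
    for s in shapes[1:]:
        assert all(s[d] == out[d] for d in range(ndim) if d != axis), \
            "non-concat dimensions must match"
    return tuple(out)

assert concat_out_shape([(3, 1, 2), (3, 1, 2)], axis=-2) == (3, 2, 2)
```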
@@ -342,7 +342,7 @@ def _test_argx(func, data, **kwargs):
         compare_tf_with_tvm(data, 'c0:0', 'argx0:0')

-def test_argmin_argmax():
+def test_forward_argminmax():
     for axis in [None,0,1,2]:
         data = np.random.uniform(size=(8,4,9)).astype('float32')
         _test_argx(tf.argmax, data=data, axis=axis)
@@ -555,6 +555,31 @@ def test_forward_lstm():
     _test_lstm_cell(1, 2, 1, 0.0, 'float32')

+
+#######################################################################
+# Pack
+# ---
+def _test_pack(axis, shape, **kwargs):
+    a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    b = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+
+    with tf.Graph().as_default():
+        tf_a = array_ops.placeholder(shape=shape, dtype='float32', name='pl_a')
+        tf_b = array_ops.placeholder(shape=shape, dtype='float32', name='pl_b')
+        tf_c = tf.stack([tf_a,tf_b], axis=axis, **kwargs)
+        assert tf_c.op.op_def.name == 'Pack', "tf.stack() is expected to produce 'Pack' operation"
+
+        compare_tf_with_tvm([a,b], ['pl_a:0','pl_b:0'], 'stack:0')
+
+def test_forward_pack():
+    for axis in range(-3,3):
+        _test_pack(axis, [3,2,1])
+    for axis in range(-1,1):
+        _test_pack(axis, [3])
+    _test_pack(0, [])
+
 #######################################################################
 # Pad
 # ---
@@ -818,9 +843,11 @@ if __name__ == '__main__':
     test_forward_reshape()
     test_forward_squeeze()
     test_forward_sigmoid()
+    test_forward_argminmax()
     if tf.__version__ == '1.4.1':
         _test_forward_concat_v2()
     test_forward_multi_input()
+    test_forward_pack()
     test_forward_inception_v3()
     test_forward_inception_v1()
     test_forward_mobilenet()
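The sweep covers negative and non-negative axes plus rank-1 and scalar edge cases: stacking two rank-3 inputs yields rank 4, so every axis in `range(-3,3)` is legal, and `_test_pack(0, [])` stacks two scalars into a length-2 vector. A quick standalone NumPy check of the same invariant (a sketch, independent of TF/TVM):

```python
import numpy as np

# The stacked axis always has length 2, the number of inputs (sketch only).
x = np.arange(6, dtype=np.float32).reshape(3, 2, 1)
for axis in range(-3, 3):
    out = np.stack([x, x], axis=axis)
    assert out.ndim == 4 and out.shape[axis] == 2
assert np.stack([np.float32(7), np.float32(7)], axis=0).shape == (2,)
```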