Commit 6514849f by Siva Committed by Tianqi Chen

[NNVM/TOPI][OP] Split : default axis to 0 and allow negative values - nump… (#1883)

parent 7631873b
...@@ -43,7 +43,7 @@ struct SplitParam : public dmlc::Parameter<SplitParam> { ...@@ -43,7 +43,7 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
DMLC_DECLARE_PARAMETER(SplitParam) { DMLC_DECLARE_PARAMETER(SplitParam) {
DMLC_DECLARE_FIELD(indices_or_sections) DMLC_DECLARE_FIELD(indices_or_sections)
.describe("Number of outputs to be splitted"); .describe("Number of outputs to be splitted");
DMLC_DECLARE_FIELD(axis).set_lower_bound(0).set_default(1) DMLC_DECLARE_FIELD(axis).set_default(1)
.describe("the axis to be splitted."); .describe("the axis to be splitted.");
} }
}; };
......
...@@ -344,14 +344,23 @@ inline bool SplitInferShape(const NodeAttrs& attrs, ...@@ -344,14 +344,23 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
const TShape& dshape = (*in_shape)[0]; const TShape& dshape = (*in_shape)[0];
if (dshape.ndim() == 0) return false; if (dshape.ndim() == 0) return false;
auto axis = param.axis;
if (axis < 0) {
axis += dshape.ndim();
}
CHECK_LT(axis, dshape.ndim())
<< "axis should be within input dimension range but got " << axis;
CHECK_GT(axis, -1)
<< "axis should be within input dimension range but got " << axis;
if (param.equal_split) { if (param.equal_split) {
int num_outputs = param.indices_or_sections[0]; int num_outputs = param.indices_or_sections[0];
CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs)); CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs));
CHECK_LT(param.axis, dshape.ndim());
TShape oshape = dshape; TShape oshape = dshape;
CHECK_EQ(oshape[param.axis] % num_outputs, 0) CHECK_EQ(oshape[axis] % num_outputs, 0)
<< "indices_or_sections need to be able to divide input.shape[axis]"; << "indices_or_sections need to be able to divide input.shape[axis] got sections "
oshape[param.axis] /= num_outputs; << num_outputs << " and dimension " << oshape[axis];
oshape[axis] /= num_outputs;
for (size_t i = 0; i < out_shape->size(); ++i) { for (size_t i = 0; i < out_shape->size(); ++i) {
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape);
...@@ -359,19 +368,19 @@ inline bool SplitInferShape(const NodeAttrs& attrs, ...@@ -359,19 +368,19 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
} else { } else {
dim_t num_outputs = param.indices_or_sections.ndim() + 1; dim_t num_outputs = param.indices_or_sections.ndim() + 1;
CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs)); CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs));
CHECK_LT(param.axis, dshape.ndim());
TShape oshape = dshape; TShape oshape = dshape;
dim_t begin = 0; dim_t begin = 0;
for (dim_t i = 0; i < num_outputs - 1; ++i) { for (dim_t i = 0; i < num_outputs - 1; ++i) {
CHECK_GT(param.indices_or_sections[i], begin) CHECK_GT(param.indices_or_sections[i], begin)
<< "indices_or_sections need to be a sorted ascending list"; << "indices_or_sections need to be a sorted ascending list got "
oshape[param.axis] = param.indices_or_sections[i] - begin; << param.indices_or_sections;
oshape[axis] = param.indices_or_sections[i] - begin;
begin = param.indices_or_sections[i]; begin = param.indices_or_sections[i];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape);
} }
CHECK_LT(begin, dshape[param.axis]) CHECK_LT(begin, dshape[axis])
<< "The sum of sections must match the input.shape[axis]"; << "The sum of sections must match the input.shape[axis]";
oshape[param.axis] = dshape[param.axis] - begin; oshape[axis] = dshape[axis] - begin;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, num_outputs - 1, oshape); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, num_outputs - 1, oshape);
} }
return true; return true;
......
...@@ -84,6 +84,10 @@ def test_split(): ...@@ -84,6 +84,10 @@ def test_split():
sdict = infer_shape(z) sdict = infer_shape(z)
assert(sdict["y"][0] == [10, 10]) assert(sdict["y"][0] == [10, 10])
assert(sdict["y"][1] == [10, 10]) assert(sdict["y"][1] == [10, 10])
z = sym.split(x1, indices_or_sections=[6], axis=-1, name="y")
sdict = infer_shape(z)
assert(sdict["y"][0] == [10, 6])
assert(sdict["y"][1] == [10, 14])
def test_batchnorm(): def test_batchnorm():
......
...@@ -4,7 +4,6 @@ from __future__ import absolute_import as _abs ...@@ -4,7 +4,6 @@ from __future__ import absolute_import as _abs
import tvm import tvm
import topi import topi
from . import tag from . import tag
from .util import ravel_index, unravel_index, get_const_int, get_const_tuple
from . import cpp from . import cpp
@tvm.tag_scope(tag=tag.BROADCAST) @tvm.tag_scope(tag=tag.BROADCAST)
...@@ -23,12 +22,7 @@ def expand_dims(a, axis, num_newaxis=1): ...@@ -23,12 +22,7 @@ def expand_dims(a, axis, num_newaxis=1):
------- -------
ret : tvm.Tensor ret : tvm.Tensor
""" """
axis = len(a.shape) + axis + 1 if axis < 0 else axis return cpp.expand_dims(a, axis, num_newaxis)
new_shape = a.shape[:axis] + ([1] * num_newaxis) + a.shape[axis:]
def _compute(*indices):
idx = indices[:axis] + indices[axis + num_newaxis:]
return a(*idx)
return tvm.compute(new_shape, _compute)
@tvm.tag_scope(tag=tag.BROADCAST) @tvm.tag_scope(tag=tag.BROADCAST)
...@@ -101,15 +95,8 @@ def transpose(a, axes=None): ...@@ -101,15 +95,8 @@ def transpose(a, axes=None):
------- -------
ret : tvm.Tensor ret : tvm.Tensor
""" """
ndim = len(a.shape) return cpp.transpose(a, axes)
axes = axes if axes else tuple(reversed(range(ndim)))
new_shape = [a.shape[x] for x in axes]
def _compute(*indices):
idx = [1] * len(axes)
for i, k in enumerate(axes):
idx[k] = indices[i]
return a(*idx)
return tvm.compute(new_shape, _compute)
def flip(a, axis=0): def flip(a, axis=0):
"""Flip/reverse elements of an array in a particular axis. """Flip/reverse elements of an array in a particular axis.
...@@ -153,6 +140,7 @@ def strided_slice(a, begin, end, strides=None): ...@@ -153,6 +140,7 @@ def strided_slice(a, begin, end, strides=None):
""" """
return cpp.strided_slice(a, begin, end, strides) return cpp.strided_slice(a, begin, end, strides)
@tvm.tag_scope(tag=tag.INJECTIVE) @tvm.tag_scope(tag=tag.INJECTIVE)
def reshape(a, newshape): def reshape(a, newshape):
"""Reshape the array """Reshape the array
...@@ -168,10 +156,7 @@ def reshape(a, newshape): ...@@ -168,10 +156,7 @@ def reshape(a, newshape):
------- -------
ret : tvm.Tensor ret : tvm.Tensor
""" """
ndim = len(a.shape) return cpp.reshape(a, newshape)
a_shape = [a.shape[i] for i in range(ndim)]
return tvm.compute(newshape,
lambda *indices: a(*unravel_index(ravel_index(indices, newshape), a_shape)))
@tvm.tag_scope(tag=tag.INJECTIVE) @tvm.tag_scope(tag=tag.INJECTIVE)
...@@ -190,41 +175,7 @@ def squeeze(a, axis=None): ...@@ -190,41 +175,7 @@ def squeeze(a, axis=None):
------- -------
squeezed : tvm.Tensor squeezed : tvm.Tensor
""" """
a_ndim = len(a.shape) return cpp.squeeze(a, axis)
a_shape = get_const_tuple(a.shape)
if axis is None:
axis = []
for i, ele in enumerate(a_shape):
if ele == 1:
axis.append(i)
else:
if isinstance(axis, int):
axis = axis + a_ndim if axis < 0 else axis
assert a_shape[axis] == 1
axis = [axis]
else:
axis = [ele + a_ndim if ele < 0 else ele for ele in axis]
for ele in axis:
assert a_shape[ele] == 1
out_shape = []
search_axis = set(axis)
for i, a_dim in enumerate(a_shape):
if i not in search_axis:
out_shape.append(a_dim)
if not out_shape:
out_shape.append(1)
def _compute(*indices):
real_indices = []
flag = 0
for i in range(a_ndim):
if i not in search_axis:
real_indices.append(indices[i - flag])
else:
real_indices.append(0)
flag += 1
return a(*real_indices)
return tvm.compute(out_shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE) @tvm.tag_scope(tag=tag.INJECTIVE)
...@@ -243,25 +194,7 @@ def concatenate(a_tuple, axis=0): ...@@ -243,25 +194,7 @@ def concatenate(a_tuple, axis=0):
------- -------
ret : tvm.Tensor ret : tvm.Tensor
""" """
assert isinstance(a_tuple, (list, tuple)) return cpp.concatenate(a_tuple, axis)
if axis < 0:
axis += len(a_tuple[0].shape)
assert axis < len(a_tuple[0].shape)
axis_sizes = [a_tuple[i].shape[axis] for i in range(len(a_tuple))]
out_shape = [a_tuple[0].shape[i] for i in range(0, axis)] + [sum(axis_sizes)]\
+ [a_tuple[0].shape[i] for i in range(axis + 1, len(a_tuple[0].shape))]
out_shape[axis] = tvm.ir_pass.Simplify(out_shape[axis])
def _compute(*indices):
ret = a_tuple[0](*indices)
ind = indices[axis]
for i in range(len(a_tuple) - 1):
ind -= axis_sizes[i]
ret = tvm.select(ind >= 0,
a_tuple[i + 1](*(indices[0:axis] + (ind,) + indices[axis + 1:])),
ret)
return ret
return tvm.compute(out_shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE) @tvm.tag_scope(tag=tag.INJECTIVE)
...@@ -280,37 +213,7 @@ def split(ary, indices_or_sections, axis=0): ...@@ -280,37 +213,7 @@ def split(ary, indices_or_sections, axis=0):
------- -------
ret : tuple of tvm.Tensor ret : tuple of tvm.Tensor
""" """
def _compute(begin, *indices): return cpp.split(ary, indices_or_sections, axis)
real_indices = indices[:axis] + (indices[axis] + begin, ) + indices[axis + 1:]
return ary(*real_indices)
if axis < 0:
axis += len(ary.shape)
src_axis_size = get_const_int(ary.shape[axis])
if isinstance(indices_or_sections, int):
assert indices_or_sections > 0
assert src_axis_size % indices_or_sections == 0
seg_size = src_axis_size // indices_or_sections
begin_ids = [seg_size * i for i in range(indices_or_sections)]
elif isinstance(indices_or_sections, (tuple, list)):
assert tuple(indices_or_sections) == tuple(sorted(indices_or_sections)),\
"Should be sorted, recieved %s" % str(indices_or_sections)
begin_ids = [0] + list(indices_or_sections)
else:
raise NotImplementedError()
out_shapes = []
for i in range(len(begin_ids)):
if i == len(begin_ids) - 1:
out_axis_size = src_axis_size - begin_ids[i]
else:
out_axis_size = begin_ids[i + 1] - begin_ids[i]
out_shapes.append([ary.shape[i] for i in range(axis)] + [out_axis_size] +\
[ary.shape[i] for i in range(axis + 1, len(ary.shape))])
# pylint: disable=cell-var-from-loop
return [tvm.compute(out_shape,
lambda *indices: _compute(begin_id, *indices), name="s%d" %i)
for i, (out_shape, begin_id) in enumerate(zip(out_shapes, begin_ids))]
# pylint: enable=cell-var-from-loop
def take(a, indices, axis=None): def take(a, indices, axis=None):
...@@ -336,6 +239,7 @@ def take(a, indices, axis=None): ...@@ -336,6 +239,7 @@ def take(a, indices, axis=None):
return cpp.take(a, indices) return cpp.take(a, indices)
return cpp.take(a, indices, int(axis)) return cpp.take(a, indices, int(axis))
def matmul(a, b, transp_a=False, transp_b=False): def matmul(a, b, transp_a=False, transp_b=False):
""" """
Creates an operation that calculates a matrix multiplication (row-major notation): Creates an operation that calculates a matrix multiplication (row-major notation):
......
...@@ -139,7 +139,7 @@ def verify_split(src_shape, indices_or_sections, axis): ...@@ -139,7 +139,7 @@ def verify_split(src_shape, indices_or_sections, axis):
with tvm.target.create(device): with tvm.target.create(device):
s = topi.generic.schedule_injective(tensor_l) s = topi.generic.schedule_injective(tensor_l)
foo = tvm.build(s, [A] + tensor_l, device, name="split") foo = tvm.build(s, [A] + list(tensor_l), device, name="split")
data_npy = np.random.normal(size=src_shape).astype(A.dtype) data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npys = np.split(data_npy, indices_or_sections, axis=axis) out_npys = np.split(data_npy, indices_or_sections, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx) data_nd = tvm.nd.array(data_npy, ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment