Commit fe0eac94 by Siju, committed by Tianqi Chen

[RELAY] take and transpose compute and schedule (#2135)

Registers FTVMCompute and the injective schedule for the relay take and transpose ops, changes topi::transpose to accept Array<Integer> axes, and adds executor tests for both ops.

parent bcacb764
......@@ -874,7 +874,7 @@ Examples::
                  const Array<Tensor>& inputs,
                  const Array<Tensor>& out_info) {
    const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
-   auto axes = ShapeToArray(param.axes);
+   auto axes = ShapeToIntArray(param.axes);
    return Array<Tensor>{ topi::transpose(inputs[0], axes) };
})
.set_attr<FGradient>(
......
......@@ -15,3 +15,5 @@ _reg.register_schedule("cast", schedule_broadcast)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)
......@@ -282,6 +282,15 @@ bool TransposeRel(const Array<Type>& types,
  return true;
}

+Array<Tensor> TransposeCompute(const Attrs& attrs,
+                               const Array<Tensor>& inputs,
+                               const Type& out_type,
+                               const Target& target) {
+  const auto* param = attrs.as<TransposeAttrs>();
+  CHECK(param != nullptr);
+  return Array<Tensor>{ topi::transpose(inputs[0], param->axes) };
+}
+
Expr MakeTranspose(Expr data,
                   Array<Integer> axes) {
  auto attrs = make_node<TransposeAttrs>();
......@@ -307,7 +316,9 @@ RELAY_REGISTER_OP("transpose")
.set_attrs_type_key("relay.attrs.TransposeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
-.add_type_rel("Transpose", TransposeRel);
+.add_type_rel("Transpose", TransposeRel)
+.set_attr<FTVMCompute>("FTVMCompute", TransposeCompute)
+.set_attr<TOpPattern>("TOpPattern", kInjective);

/* relay.reshape */
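With FTVMCompute registered, the compile engine can lower relay transpose to topi::transpose, and the kInjective pattern lets the fusion pass merge it with adjacent injective ops. A minimal end-to-end sketch of what this unlocks (assumes the relay API of this TVM version; illustration only):

import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((2, 3, 4), "float32"))
func = relay.Function([x], relay.transpose(x, axes=(0, 2, 1)))
# before this commit, building would fail for lack of a registered compute/schedule
graph, lib, params = relay.build(func, target="llvm")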
......@@ -575,6 +586,19 @@ bool TakeRel(const Array<Type>& types,
  return true;
}

+Array<Tensor> TakeCompute(const Attrs& attrs,
+                          const Array<Tensor>& inputs,
+                          const Type& out_type,
+                          const Target& target) {
+  const auto* param = attrs.as<TakeAttrs>();
+  CHECK(param != nullptr);
+  if (!param->axis.defined()) {
+    return Array<Tensor>{ topi::take(inputs[0], inputs[1]) };
+  } else {
+    return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) };
+  }
+}
+
Expr MakeTake(Expr data,
              Expr indices,
              Integer axis) {
......@@ -617,7 +641,10 @@ Examples::
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(2)
.add_type_rel("Take", TakeRel);
.add_type_rel("Take", TakeRel)
.set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// Init ops
TVM_REGISTER_NODE_TYPE(InitOpAttrs);
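TakeCompute branches on whether an axis was supplied, mirroring NumPy semantics (the new tests below use np.take as the reference): with no axis the data is indexed as if flattened, otherwise elements are gathered along the given axis. A short NumPy sketch of the two branches (illustration only):

import numpy as np

x = np.arange(9.0).reshape(3, 3)
idx = np.array([[1, 2], [0, 1]], dtype="int32")
np.take(x, idx)          # axis=None branch: flat indexing into x.ravel()
np.take(x, idx, axis=1)  # axis branch: gather along columns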
......
......@@ -87,6 +87,22 @@ def test_transpose_infer_type():
    assert yy.checked_type == relay.TensorType(
        (t, n, 100), "float32")

+def test_transpose():
+    def verify_transpose(dshape, axes):
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        z = relay.transpose(x, axes=axes)
+        func = relay.Function([x], z)
+        x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
+        ref_res = np.transpose(x_data, axes=axes)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    verify_transpose((2, 3, 4), (0, 2, 1))
+
def test_squeeze_infer_type():
    n, t, d = 1, 4, 1
......@@ -202,6 +218,35 @@ def test_take_infer_type():
    verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
    verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)

+def test_take():
+    def verify_take(src_shape, indices_src, axis=None):
+        src_dtype = "float32"
+        indices_dtype = "int32"
+        indices_src = np.array(indices_src, dtype=indices_dtype)
+        x = relay.var("x", relay.TensorType(src_shape, src_dtype))
+        indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
+        z = relay.take(x, indices, axis=axis)
+        func = relay.Function([x, indices], z)
+        x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
+        ref_res = np.take(x_data, indices=indices_src, axis=axis)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data, indices_src)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    verify_take((4,), [1])
+    verify_take((4,), [[0,1,2,3]])
+    verify_take((3,3,3), [[11,25]])
+    verify_take((4,), [[0,1],[2,3]])
+    verify_take((4,), [1], 0)
+    verify_take((2,2), [[[1,0],[0,1]]], 0)
+    verify_take((2,2), [[[1,0],[0,1]]], 1)
+    verify_take((4,3,5,6), [[2,1,0,0]], -2)
+
def test_split_infer_type():
    def verify_split(dshape, indices_or_sections, ret_type, axis=None):
        x = relay.var("x", relay.ty.TensorType(dshape, "float32"))
......@@ -360,11 +405,13 @@ if __name__ == "__main__":
    test_unary_identity()
    test_clip()
    test_transpose_infer_type()
+    test_transpose()
    test_reshape_infer_type()
    test_reshape()
    test_reshape_like_infer_type()
    test_reshape_like()
    test_take_infer_type()
+    test_take()
    test_full()
    test_full_like()
    test_infer_type_leaky_relu()
......
......@@ -86,42 +86,45 @@ inline Tensor expand_dims(const Tensor& x,
 * \return A Tensor whose op member is the transpose operation
 */
inline Tensor transpose(const Tensor& x,
-                        Array<Expr> axes,
+                        Array<Integer> axes,
                        std::string name = "tensor",
                        std::string tag = kInjective) {
-  if (axes.size() == 0) {
-    axes = Array<Expr>();
+  if (!axes.defined() || axes.size() == 0) {
+    axes = Array<Integer>();
    for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
      axes.push_back(i);
    }
  }

-  auto axes_val = GetConstIntValues(axes, "axes");
-  for (size_t i = 0; i < axes_val.size(); ++i) {
-    int axis = axes_val[i];
-    if (axes_val[i] < 0) {
-      axes_val[i] = static_cast<int>(x->shape.size()) + axes_val[i];
-    }
-    CHECK((0 <= axes_val[i]) && (axes_val[i] < static_cast<int>(x->shape.size())))
-      << "axis=" << axis << " is invalid for the "
-      << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
-    CHECK(1 == std::count(std::begin(axes_val), std::end(axes_val), axes_val[i]))
-      << "repeated axis in transpose";
-  }
-
-  Array<Expr> new_shape;
-  for (size_t i = 0; i < axes_val.size(); ++i) {
-    new_shape.push_back(x->shape[axes_val[i]]);
-  }
+  Array<Expr> new_shape;
+  for (size_t i = 0; i < axes.size(); ++i) {
+    int axis = static_cast<int>(axes[i]->value);
+    int new_axis = axis;
+    if (axis < 0) {
+      new_axis = static_cast<int>(x->shape.size()) + axis;
+      axes.Set(i, new_axis);
+    }
+    CHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
+      << "axis=" << axis << " is invalid for the "
+      << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
+    for (size_t j = 0; j < axes.size(); ++j) {
+      if (i != j) {
+        CHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
+      }
+    }
+    new_shape.push_back(x->shape[new_axis]);
+  }
  return compute(
    new_shape, [&](const Array<Var>& indices) {
      std::vector<Expr> idx;
-      for (size_t i = 0; i < axes_val.size(); ++i) {
+      for (size_t i = 0; i < axes.size(); ++i) {
        idx.push_back(1);
      }
-      for (size_t i = 0; i < axes_val.size(); ++i) {
-        idx[axes_val[i]] = indices[i];
+      for (size_t i = 0; i < axes.size(); ++i) {
+        int axis = static_cast<int>(axes[i]->value);
+        idx[axis] = indices[i];
      }
      return x(idx);
    }, name, tag);
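The compute rule maps each output coordinate back to an input coordinate via idx[axes[i]] = indices[i]. A small NumPy sketch of that mapping (illustration only, not part of the diff):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
axes = (0, 2, 1)
out = np.empty([x.shape[a] for a in axes], dtype=x.dtype)
for out_idx in np.ndindex(out.shape):
    in_idx = [0] * x.ndim
    for i, axis in enumerate(axes):
        in_idx[axis] = out_idx[i]      # idx[axes[i]] = indices[i]
    out[out_idx] = x[tuple(in_idx)]
assert (out == np.transpose(x, axes)).all()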
......