Commit fe0eac94 by Siju Committed by Tianqi Chen

[RELAY]take and transpose comp and schd (#2135)

parent bcacb764
...@@ -874,7 +874,7 @@ Examples:: ...@@ -874,7 +874,7 @@ Examples::
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed); const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
auto axes = ShapeToArray(param.axes); auto axes = ShapeToIntArray(param.axes);
return Array<Tensor>{ topi::transpose(inputs[0], axes) }; return Array<Tensor>{ topi::transpose(inputs[0], axes) };
}) })
.set_attr<FGradient>( .set_attr<FGradient>(
......
...@@ -15,3 +15,5 @@ _reg.register_schedule("cast", schedule_broadcast) ...@@ -15,3 +15,5 @@ _reg.register_schedule("cast", schedule_broadcast)
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective) _reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)
...@@ -282,6 +282,15 @@ bool TransposeRel(const Array<Type>& types, ...@@ -282,6 +282,15 @@ bool TransposeRel(const Array<Type>& types,
return true; return true;
} }
Array<Tensor> TransposeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<TransposeAttrs>();
CHECK(param != nullptr);
return Array<Tensor>{ topi::transpose(inputs[0], param->axes) };
}
Expr MakeTranspose(Expr data, Expr MakeTranspose(Expr data,
Array<Integer> axes) { Array<Integer> axes) {
auto attrs = make_node<TransposeAttrs>(); auto attrs = make_node<TransposeAttrs>();
...@@ -307,7 +316,9 @@ RELAY_REGISTER_OP("transpose") ...@@ -307,7 +316,9 @@ RELAY_REGISTER_OP("transpose")
.set_attrs_type_key("relay.attrs.TransposeAttrs") .set_attrs_type_key("relay.attrs.TransposeAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Transpose", TransposeRel); .add_type_rel("Transpose", TransposeRel)
.set_attr<FTVMCompute>("FTVMCompute", TransposeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
/* relay.reshape */ /* relay.reshape */
...@@ -575,6 +586,19 @@ bool TakeRel(const Array<Type>& types, ...@@ -575,6 +586,19 @@ bool TakeRel(const Array<Type>& types,
return true; return true;
} }
Array<Tensor> TakeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
if (!param->axis.defined()) {
return Array<Tensor>{ topi::take(inputs[0], inputs[1]) };
} else {
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) };
}
}
Expr MakeTake(Expr data, Expr MakeTake(Expr data,
Expr indices, Expr indices,
Integer axis) { Integer axis) {
...@@ -617,7 +641,10 @@ Examples:: ...@@ -617,7 +641,10 @@ Examples::
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.") .add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("Take", TakeRel); .add_type_rel("Take", TakeRel)
.set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// Init ops // Init ops
TVM_REGISTER_NODE_TYPE(InitOpAttrs); TVM_REGISTER_NODE_TYPE(InitOpAttrs);
......
...@@ -87,6 +87,22 @@ def test_transpose_infer_type(): ...@@ -87,6 +87,22 @@ def test_transpose_infer_type():
assert yy.checked_type == relay.TensorType( assert yy.checked_type == relay.TensorType(
(t, n, 100), "float32") (t, n, 100), "float32")
def test_transpose():
def verify_transpose(dshape, axes):
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.transpose(x, axes=axes)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.transpose(x_data, axes=axes)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_transpose((2, 3, 4), (0, 2, 1))
def test_squeeze_infer_type(): def test_squeeze_infer_type():
n, t, d = 1, 4, 1 n, t, d = 1, 4, 1
...@@ -202,6 +218,35 @@ def test_take_infer_type(): ...@@ -202,6 +218,35 @@ def test_take_infer_type():
verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1) verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
def test_take():
def verify_take(src_shape, indices_src, axis=None):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
x = relay.var("x", relay.TensorType(src_shape, src_dtype))
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
z = relay.take(x, indices, axis=axis)
func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, indices_src)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_take((4,), [1])
verify_take((4,), [[0,1,2,3]])
verify_take((3,3,3), [[11,25]])
verify_take((4,), [[0,1],[2,3]])
verify_take((4,), [1], 0)
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
def test_split_infer_type(): def test_split_infer_type():
def verify_split(dshape, indices_or_sections, ret_type, axis=None): def verify_split(dshape, indices_or_sections, ret_type, axis=None):
x = relay.var("x", relay.ty.TensorType(dshape, "float32")) x = relay.var("x", relay.ty.TensorType(dshape, "float32"))
...@@ -360,11 +405,13 @@ if __name__ == "__main__": ...@@ -360,11 +405,13 @@ if __name__ == "__main__":
test_unary_identity() test_unary_identity()
test_clip() test_clip()
test_transpose_infer_type() test_transpose_infer_type()
test_transpose()
test_reshape_infer_type() test_reshape_infer_type()
test_reshape() test_reshape()
test_reshape_like_infer_type() test_reshape_like_infer_type()
test_reshape_like() test_reshape_like()
test_take_infer_type() test_take_infer_type()
test_take()
test_full() test_full()
test_full_like() test_full_like()
test_infer_type_leaky_relu() test_infer_type_leaky_relu()
......
...@@ -86,42 +86,45 @@ inline Tensor expand_dims(const Tensor& x, ...@@ -86,42 +86,45 @@ inline Tensor expand_dims(const Tensor& x,
* \return A Tensor whose op member is the transpose operation * \return A Tensor whose op member is the transpose operation
*/ */
inline Tensor transpose(const Tensor& x, inline Tensor transpose(const Tensor& x,
Array<Expr> axes, Array<Integer> axes,
std::string name = "tensor", std::string name = "tensor",
std::string tag = kInjective) { std::string tag = kInjective) {
if (axes.size() == 0) { if (!axes.defined() || axes.size() == 0) {
axes = Array<Expr>(); axes = Array<Integer>();
for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) { for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
axes.push_back(i); axes.push_back(i);
} }
} }
auto axes_val = GetConstIntValues(axes, "axes"); Array<Expr> new_shape;
for (size_t i = 0; i < axes_val.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
int axis = axes_val[i]; int axis = static_cast<int>(axes[i]->value);
if (axes_val[i] < 0) { int new_axis = axis;
axes_val[i] = static_cast<int>(x->shape.size()) + axes_val[i]; if (axis < 0) {
new_axis = static_cast<int>(x->shape.size()) + axis;
axes.Set(i, new_axis);
} }
CHECK((0 <= axes_val[i]) && (axes_val[i] < static_cast<int>(x->shape.size()))) CHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
<< "axis=" << axis << " is invalid for the " << "axis=" << axis << " is invalid for the "
<< static_cast<int>(x->shape.size()) << "-dimensional input tensor"; << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
CHECK(1 == std::count(std::begin(axes_val), std::end(axes_val), axes_val[i])) for (size_t j = 0; j < axes.size(); ++j) {
<< "repeated axis in transpose"; if (i !=j) {
CHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
} }
Array<Expr> new_shape;
for (size_t i = 0; i < axes_val.size(); ++i) {
new_shape.push_back(x->shape[axes_val[i]]);
} }
new_shape.push_back(x->shape[new_axis]);
}
return compute( return compute(
new_shape, [&](const Array<Var>& indices) { new_shape, [&](const Array<Var>& indices) {
std::vector<Expr> idx; std::vector<Expr> idx;
for (size_t i = 0; i < axes_val.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
idx.push_back(1); idx.push_back(1);
} }
for (size_t i = 0; i < axes_val.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
idx[axes_val[i]] = indices[i]; int axis = static_cast<int>(axes[i]->value);
idx[axis] = indices[i];
} }
return x(idx); return x(idx);
}, name, tag); }, name, tag);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment