Commit 160eaf54 authored by Junru Shao, committed by Tianqi Chen

[RELAY][OP] expand_dims (#1819)

parent 7f745fbe
@@ -26,6 +26,7 @@ This level enables fully connected multi-layer perceptron.
   tvm.relay.sqrt
   tvm.relay.exp
   tvm.relay.add
   tvm.relay.expand_dims

**Level 2: Convolutions**
/*!
 * Copyright (c) 2018 by Contributors
 * \file tvm/relay/attrs/transform.h
 * \brief Transform operators.
 */
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
  int axis;
  int num_newaxis;

  TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") {
    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input array is expanded. "
                  "Should lie in range `[-data.ndim - 1, data.ndim]`. "
                  "If `axis >= 0`, the first inserted axis is placed at position `axis`; "
                  "if `axis < 0`, the last inserted axis lands at position `axis` "
                  "under Python-style negative indexing.");
    TVM_ATTR_FIELD(num_newaxis)
        .describe("Number of axes to be inserted. Should be >= 0.")
        .set_lower_bound(0)
        .set_default(1);
  }
};  // struct ExpandDimsAttrs

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
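The axis bookkeeping above is easy to get wrong, so here is a small plain-Python sketch of how `axis` and `num_newaxis` determine the output shape. It is not part of the commit; `expand_dims_shape` is a hypothetical helper mirroring the semantics described by the attributes.

# Hypothetical helper, not part of the commit.
def expand_dims_shape(shape, axis, num_newaxis=1):
    ndim = len(shape)
    assert -ndim - 1 <= axis <= ndim, "axis out of range"
    assert num_newaxis >= 0, "num_newaxis must be non-negative"
    pivot = ndim + axis + 1 if axis < 0 else axis  # normalize negative axis
    return shape[:pivot] + (1,) * num_newaxis + shape[pivot:]

assert expand_dims_shape((2, 3), axis=1, num_newaxis=2) == (2, 1, 1, 3)
assert expand_dims_shape((2, 3), axis=-1) == (2, 3, 1)  # appended at the end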
@@ -11,6 +11,7 @@ from . import ir_builder
from .op import Op
from .op.tensor import *
from . import nn
from .op.transform import *

# Span
Span = base.Span
@@ -6,6 +6,7 @@ from .op import get, register, Op
# Operators
from .tensor import *
from . import nn
from .transform import *

# operator registry
"""Transform operators."""
from . import _make
def expand_dims(data, axis, num_newaxis=1):
"""Insert `num_newaxis` axises at the position given by `axis`.
Parameters
----------
data : relay.Expr
The input data to the operator.
axis : int
The axis at which the input array is expanded.
Should lie in range `[-data.ndim - 1, data.ndim]`.
If `axis < 0`, it is the first axis inserted;
If `axis >= 0`, it is the last axis inserted in Python's negative indexing.
num_newaxis : int
Number of axises to be inserted. Should be >= 0.
Returns
-------
result : relay.Expr
The reshaped result.
"""
return _make.expand_dims(data, axis, num_newaxis)
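For intuition, this operator generalizes `numpy.expand_dims`, which inserts a single axis; `num_newaxis=2` corresponds to applying it twice. An illustrative comparison (not part of the commit):

import numpy as np

x = np.zeros((2, 3))
y = np.expand_dims(np.expand_dims(x, axis=1), axis=1)
assert y.shape == (2, 1, 1, 3)  # same shape relay.expand_dims(x, 1, 2) infers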
/*!
 * Copyright (c) 2018 by Contributors
 * \file transform.cc
 * \brief Transform operators.
 */
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <vector>

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);

bool ExpandDimsRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  // `types` contains: [data, output]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  const ExpandDimsAttrs* param = attrs.as<ExpandDimsAttrs>();
  const int ndim = static_cast<int>(data->shape.size());
  const int axis = param->axis;
  const int num_newaxis = param->num_newaxis;
  CHECK(num_newaxis >= 0)
    << "expand_dims only accepts `num_newaxis >= 0`"
    << ", but got num_newaxis = " << num_newaxis;
  CHECK(-ndim - 1 <= axis && axis <= ndim)
    << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
    << ", but got axis = " << axis
    << ", and data.ndim = " << ndim;
  // Normalize a negative axis to the insertion point (pivot) in [0, ndim].
  const int pivot = axis < 0 ? ndim + axis + 1 : axis;
  std::vector<IndexExpr> oshape;
  oshape.reserve(ndim + num_newaxis);
  for (int i = 0; i < pivot; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
  for (int i = 0; i < num_newaxis; ++i) {
    oshape.emplace_back(1);
  }
  for (int i = pivot; i < ndim; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
  return true;
}

Expr MakeExpandDims(Expr data,
                    int axis,
                    int num_newaxis) {
  auto attrs = make_node<ExpandDimsAttrs>();
  attrs->axis = axis;
  attrs->num_newaxis = num_newaxis;
  static const Op& op = Op::Get("expand_dims");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.expand_dims")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 3>(MakeExpandDims, args, rv);
  });

RELAY_REGISTER_OP("expand_dims")
.describe(R"code(Insert `num_newaxis` axes at the position given by `axis`.

- **data**: The input data to the operator.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel);

}  // namespace relay
}  // namespace tvm
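`ExpandDimsRel` follows Relay's type-relation contract: return false to defer when an input type is still unresolved, otherwise assign the output type through the reporter and return true. A plain-Python model of that contract (hypothetical names, not the TVM API):

# Pseudo-model of the type-relation contract; not TVM code.
class TensorType:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = list(shape), dtype

def expand_dims_rel(types, axis, num_newaxis):
    data = types[0]
    if data is None:  # input type not yet inferred: ask the solver to retry
        return False
    ndim = len(data.shape)
    pivot = ndim + axis + 1 if axis < 0 else axis
    types[1] = TensorType(data.shape[:pivot] + [1] * num_newaxis +
                          data.shape[pivot:], data.dtype)
    return True

types = [TensorType((2, 3), "float32"), None]
assert expand_dims_rel(types, axis=1, num_newaxis=2)
assert types[1].shape == [2, 1, 1, 3]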
import tvm
from tvm import relay


def test_expand_dims_infer_type():
    ib = relay.ir_builder.IRBuilder()
    n, t, d = tvm.var("n"), tvm.var("t"), 100
    # let's mimic a batch of sequences
    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.expand_dims(x, axis=2))
    ib.ret(func)

    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type()
    assert ftype.ret_type == relay.ty.TensorType(
        (n, t, 1, 100), "float32")


if __name__ == "__main__":
    test_expand_dims_infer_type()
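A variant exercising a negative axis could look like this (a sketch, not part of the commit, assuming the same IRBuilder API as the test above):

def test_expand_dims_negative_axis():
    ib = relay.ir_builder.IRBuilder()
    n, t, d = tvm.var("n"), tvm.var("t"), 100
    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.expand_dims(x, axis=-1, num_newaxis=2))
    ib.ret(func)

    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type()
    # axis=-1 appends both new axes after the last dimension
    assert ftype.ret_type == relay.ty.TensorType(
        (n, t, 100, 1, 1), "float32")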
@@ -37,10 +37,18 @@ inline Tensor expand_dims(const Tensor& x,
                           int axis,
                           int num_newaxis = 1,
                           std::string name = "tensor",
                           std::string tag = kBroadcast) {
+  int ndim = static_cast<int>(x->shape.size());
   if (axis < 0) {
     // Calculate offset from last dimension
-    axis = static_cast<int>(x->shape.size()) + axis + 1;
+    axis = ndim + axis + 1;
   }
+  CHECK(-ndim - 1 <= axis && axis <= ndim)
+      << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
+      << ", but got axis = " << axis
+      << ", and data.ndim = " << ndim;
+  CHECK(num_newaxis >= 0)
+      << "expand_dims only accepts `num_newaxis >= 0`"
+      << ", but got num_newaxis = " << num_newaxis;
   Array<Expr> new_shape;
   for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
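As a quick sanity check of the normalization in this hunk (illustrative arithmetic, not from the commit): with ndim = 3, every legal negative axis maps back into [0, 3], e.g. axis = -1 becomes 3 + (-1) + 1 = 3, so the new axes go after the last dimension.

ndim = 3
for axis, expected in [(-4, 0), (-1, 3), (0, 0), (3, 3)]:
    pivot = ndim + axis + 1 if axis < 0 else axis
    assert pivot == expected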