Commit 710af087 by Junru Shao, committed by Tianqi Chen

[Relay][Op] concatenate, reshape, transpose, copy (#1847)

parent cb2a599d
......@@ -28,6 +28,7 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.sigmoid
tvm.relay.add
tvm.relay.expand_dims
tvm.relay.concatenate
tvm.relay.nn.softmax
**Level 2: Convolutions**
......@@ -47,6 +48,9 @@ This level enables typical convnet models.
tvm.relay.zeros_like
tvm.relay.ones_like
tvm.relay.reshape
tvm.relay.copy
tvm.relay.transpose
**Level 4: Broadcast and Reductions**
......
......@@ -30,6 +30,35 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
}
}; // struct ExpandDimsAttrs
/*! \brief Attributes used in concatenate operators */
struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
int axis;
TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") {
TVM_ATTR_FIELD(axis)
.describe("The axis at which the input arrays are concatenated."
"Should lie in range `[-ndim, ndim)`.")
.set_default(0);
}
}; // struct ConcatenateAttrs
/*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
Array<IndexExpr> axes;
TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The target axes order, reverse order if not specified.");
}
}; // struct TransposeAttrs
/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<IndexExpr> newshape;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape)
.describe("The new shape. Should be compatible with the original shape.");
}
}; // struct ReshapeAttrs
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
......@@ -300,21 +300,6 @@ def left_shift(lhs, rhs):
return _make.left_shift(lhs, rhs)
def concat(*args):
"""Concatenate the input tensors along the zero axis.
Parameters
----------
args: list of Tensor
Returns
-------
tensor: The concatenated tensor.
"""
tup = Tuple(list(args))
return _make.concat(tup)
def zeros_like(data):
"""Returns an array of zeros, with same type and shape as the input.
......@@ -345,3 +330,41 @@ def ones_like(data):
The computed result.
"""
return _make.ones_like(data)
def concatenate(data, axis):
"""Concatenate the input tensors along the given axis.
Parameters
----------
data : Union(List[relay.Expr], Tuple[relay.Expr])
A list of tensors.
axis : int
The axis along which the tensors are concatenated.
Returns
-------
result: relay.Expr
The concatenated tensor.
"""
data = list(data)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
return _make.concatenate(Tuple(data), axis)
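A minimal usage sketch, mirroring the type-inference tests added below (the shapes are illustrative):

import tvm
from tvm import relay

ib = relay.ir_builder.IRBuilder()
n = tvm.var("n")
x = ib.param("x", relay.ty.TensorType((n, 100), "float32"))
y = ib.param("y", relay.ty.TensorType((n, 100), "float32"))
with ib.function(x, y) as func:
    ib.ret(relay.concatenate((x, y), axis=-1))  # negative axis counts from the back
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
# func.checked_type.ret_type == relay.ty.TensorType((n, 200), "float32")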
def copy(data):
"""Copy a tensor.
Parameters
----------
data : relay.Expr
The tensor to be copied.
Returns
-------
result: relay.Expr
The copied result.
"""
return _make.copy(data)
......@@ -26,3 +26,93 @@ def expand_dims(data, axis, num_newaxis=1):
The expanded result.
"""
return _make.expand_dims(data, axis, num_newaxis)
def transpose(data, axes=None):
"""Permutes the dimensions of an array.
Parameters
----------
data : relay.Expr
The input data to the operator.
axes : None or List[int]
The target axes order, reverse order if not specified.
Returns
-------
result : relay.Expr
The transposed result.
"""
axes = axes or []
return _make.transpose(data, list(axes))
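For intuition, the `axes` convention is the same as numpy.transpose (a NumPy analogy for the semantics, not the Relay call itself): omitting `axes` reverses the dimension order.

import numpy as np

x = np.zeros((2, 3, 4))
assert np.transpose(x).shape == (4, 3, 2)             # axes omitted: reversed order
assert np.transpose(x, (1, 0, 2)).shape == (3, 2, 4)  # explicit permutation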
def reshape(data, newshape):
"""Reshapes the input array.
To give the user more convenience and avoid manual shape inference,
some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}.
The significance of each is explained below:
- ``0`` copy this dimension from the input to the output shape.
Example::
- data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2)
- data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4)
- ``-1`` infers the dimension of the output shape from the remainder of the input dimensions,
keeping the size of the new array the same as that of the input array.
At most one dimension of the new shape can be -1.
Example::
- data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)
- data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8)
- data.shape = (2,3,4), newshape = (-1,), result.shape = (24,)
- ``-2`` copy all/remainder of the input dimensions to the output shape.
Example::
- data.shape = (2,3,4), newshape = (-2,), result.shape = (2,3,4)
- data.shape = (2,3,4), newshape = (2,-2), result.shape = (2,3,4)
- data.shape = (2,3,4), newshape = (-2,1,1), result.shape = (2,3,4,1,1)
- ``-3`` use the product of two consecutive dimensions of the input shape
as the output dimension.
Example::
- data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4)
- data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20)
- data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12)
- data.shape = (2,3,4), newshape = (-3,-2), result.shape = (6,4)
- ``-4`` split one dimension of the input into the two dimensions that follow -4
in the new shape (one of which can be -1).
Example::
- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
Parameters
----------
data : relay.Expr
The input data to the operator.
newshape : Union[int, Tuple[int], List[int]]
The new shape. Should be compatible with the original shape.
Returns
-------
result : relay.Expr
The reshaped result.
"""
if isinstance(newshape, int):
newshape = [newshape]
return _make.reshape(data, list(newshape))
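Since the special values take some unpacking, here is a hedged plain-Python sketch of how a backend might interpret them (`_infer_newshape` is an illustrative helper, not part of Relay); it reproduces the examples from the docstring above:

from functools import reduce

def _infer_newshape(shape, newshape):
    out = []
    src = 0          # read position in the input shape
    dst = 0          # read position in newshape
    infer_idx = -1   # position of a top-level -1, if any
    while dst < len(newshape):
        v = newshape[dst]
        if v > 0:                     # explicit output dimension
            out.append(v)
            src += 1
        elif v == 0:                  # copy the matching input dimension
            out.append(shape[src])
            src += 1
        elif v == -1:                 # infer from the remaining size
            assert infer_idx < 0, "at most one -1 is allowed"
            infer_idx = len(out)
            out.append(1)
            src += 1
        elif v == -2:                 # copy all remaining input dimensions
            out.extend(shape[src:])
            src = len(shape)
        elif v == -3:                 # fuse two consecutive input dimensions
            out.append(shape[src] * shape[src + 1])
            src += 2
        elif v == -4:                 # split one input dimension in two
            d1, d2 = newshape[dst + 1], newshape[dst + 2]
            whole = shape[src]
            d1 = whole // d2 if d1 == -1 else d1
            d2 = whole // d1 if d2 == -1 else d2
            out.extend([d1, d2])
            src += 1
            dst += 2                  # the two factors were consumed here
        dst += 1
    if infer_idx >= 0:
        total = reduce(lambda a, b: a * b, shape, 1)
        rest = reduce(lambda a, b: a * b, out, 1)
        out[infer_idx] = total // rest
    return tuple(out)

assert _infer_newshape((2, 3, 4), (4, 0, 2)) == (4, 3, 2)
assert _infer_newshape((2, 3, 4), (6, 1, -1)) == (6, 1, 4)
assert _infer_newshape((2, 3, 4), (-2, 1, 1)) == (2, 3, 4, 1, 1)
assert _infer_newshape((2, 3, 4, 5), (-3, -3)) == (6, 20)
assert _infer_newshape((2, 3, 4), (2, -4, -1, 3, -2)) == (2, 1, 3, 4)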
/*!
* Copyright (c) 2018 by Contributors
* \file op_common.h
* \brief A set of utilities and common functionality
* for relay ops.
*/
#ifndef TVM_RELAY_OP_OP_COMMON_H_
#define TVM_RELAY_OP_OP_COMMON_H_
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <vector>
namespace tvm {
namespace relay {
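/*!
 * \brief Convert a tvm::Array into a std::vector by copying each element.
 */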
template<typename T>
std::vector<T> AsVector(const Array<T> &array) {
std::vector<T> result;
result.reserve(array.size());
for (const T& ele : array) {
result.push_back(ele);
}
return result;
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_OP_COMMON_H_
......@@ -5,25 +5,29 @@
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <vector>
#include "../op_common.h"
namespace tvm {
namespace relay {
/* relay.expand_dims */
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
bool ExpandDimsRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, output]
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const ExpandDimsAttrs* param = attrs.as<ExpandDimsAttrs>();
const auto* param = attrs.as<ExpandDimsAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const int axis = param->axis;
const int num_newaxis = param->num_newaxis;
......@@ -76,6 +80,240 @@ RELAY_REGISTER_OP("expand_dims")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel);
/* relay.concatenate */
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
bool ConcatenateRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) {
return false;
}
const auto* param = attrs.as<ConcatenateAttrs>();
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
// Sanity check: ndim and dtype.
const int ndim = static_cast<int>(first->shape.size());
const DataType dtype = first->dtype;
for (const Type& ele : tensor_tuple->fields) {
const auto& e = Downcast<TensorType>(ele);
int e_ndim = static_cast<int>(e->shape.size());
const DataType& e_dtype = e->dtype;
CHECK_EQ(e_ndim, ndim) << "relay.concatenate requires all tensors to have the same ndim";
CHECK_EQ(e_dtype, dtype) << "relay.concatenate requires all tensors to have the same dtype";
}
// Sanity check: axis
int axis = param->axis;
CHECK(-ndim <= axis && axis < ndim)
<< "concatenate only accepts `axis` in [-ndim, ndim)"
<< ", but got axis = " << axis
<< ", and ndim = " << ndim;
axis = axis < 0 ? ndim + axis : axis;
// Calculate shape
std::vector<IndexExpr> oshape = AsVector(first->shape);
IndexExpr& concat_dim = oshape[axis];
for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
concat_dim += e->shape[axis];
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype));
return true;
}
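The shape rule reads directly: keep the first tensor's shape and sum the concatenated axis over all inputs. A hedged Python mirror of the same arithmetic (`_concat_oshape` is an illustrative name, not part of the codebase):

def _concat_oshape(shapes, axis):
    ndim = len(shapes[0])
    assert all(len(s) == ndim for s in shapes)  # same ndim everywhere
    axis = axis + ndim if axis < 0 else axis    # normalize a negative axis
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)    # concat dim is the sum
    return tuple(out)

assert _concat_oshape([(2, 3), (2, 5)], axis=-1) == (2, 8)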
Expr MakeConcatenate(Expr data,
int axis) {
auto attrs = make_node<ConcatenateAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("concatenate");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.concatenate")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeConcatenate, args, rv);
});
RELAY_REGISTER_OP("concatenate")
.describe(R"code(Concatenate the input tensors along the given axis.
- **data** : A list of tensors.
- **axis** : The axis along which the tensors are concatenated.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel);
/* relay.transpose */
bool TransposeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* param = attrs.as<TransposeAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const Array<IndexExpr>& axes = param->axes;
// check dimension match
CHECK(axes.empty() || static_cast<int>(axes.size()) == ndim)
<< "Dimension mismatch: axes has " << axes.size() << " elements"
<< ", but data.ndim = " << ndim;
// construct int_axes
std::vector<int> int_axes;
int_axes.reserve(ndim);
if (axes.empty()) {
for (int i = ndim - 1; i >= 0; --i) {
int_axes.push_back(i);
}
} else {
std::vector<int> axis_used(ndim, 0);
for (const IndexExpr& e : axes) {
const int64_t *axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
int axis = static_cast<int>(*axis_ptr);
// sanity check for axis and ndim
CHECK(-ndim <= axis && axis < ndim)
<< "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
<< ", but got axis = " << axis
<< ", and data.ndim = " << ndim;
axis = axis < 0 ? axis + ndim : axis;
// sanity check for duplication
CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
axis_used[axis] = 1;
int_axes.push_back(axis);
}
}
std::vector<IndexExpr> oshape;
oshape.reserve(ndim);
for (int axis : int_axes) {
oshape.push_back(data->shape[axis]);
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
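In short: empty `axes` means reversed order; otherwise each axis is normalized and checked for duplicates before the shape is permuted. A hedged Python mirror (`_transpose_oshape` is an illustrative name):

def _transpose_oshape(shape, axes):
    ndim = len(shape)
    if not axes:                                     # default: reversed order
        axes = list(range(ndim - 1, -1, -1))
    assert len(axes) == ndim, "dimension mismatch"
    axes = [a + ndim if a < 0 else a for a in axes]  # normalize negatives
    assert len(set(axes)) == ndim, "duplicate axes in transpose"
    return tuple(shape[a] for a in axes)

assert _transpose_oshape((2, 3, 4), []) == (4, 3, 2)
assert _transpose_oshape((2, 3, 4), [1, 0, 2]) == (3, 2, 4)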
Expr MakeTranspose(Expr data,
Array<IndexExpr> axes) {
auto attrs = make_node<TransposeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("transpose");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.transpose")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeTranspose, args, rv);
});
RELAY_REGISTER_OP("transpose")
.describe(R"code(Permutes the dimensions of an array.
- **data**: The input data to the operator.
- **axes**: The target axes order, reverse order if not specified.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Transpose", TransposeRel);
/* relay.reshape */
bool ReshapeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* param = attrs.as<ReshapeAttrs>();
reporter->Assign(types[1], TensorTypeNode::make(param->newshape, data->dtype));
return true;
}
Expr MakeReshape(Expr data,
Array<IndexExpr> newshape) {
auto attrs = make_node<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
static const Op& op = Op::Get("reshape");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.reshape")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReshape, args, rv);
});
RELAY_REGISTER_OP("reshape")
.describe(R"code(Reshapes the input array.
To give the user more convenience and avoid manual shape inference,
some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}.
The significance of each is explained below:
- ``0`` copy this dimension from the input to the output shape.
Example::
- data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2)
- data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4)
- ``-1`` infers the dimension of the output shape from the remainder of the input dimensions,
keeping the size of the new array the same as that of the input array.
At most one dimension of the new shape can be -1.
Example::
- data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)
- data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8)
- data.shape = (2,3,4), newshape = (-1,), result.shape = (24,)
- ``-2`` copy all/remainder of the input dimensions to the output shape.
Example::
- data.shape = (2,3,4), newshape = (-2,), result.shape = (2,3,4)
- data.shape = (2,3,4), newshape = (2,-2), result.shape = (2,3,4)
- data.shape = (2,3,4), newshape = (-2,1,1), result.shape = (2,3,4,1,1)
- ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.
Example::
- data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4)
- data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20)
- data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12)
- data.shape = (2,3,4), newshape = (-3,-2), result.shape = (6,4)
- ``-4`` split one dimension of the input into the two dimensions that follow -4 in the new shape (one of which can be -1).
Example::
- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);
} // namespace relay
} // namespace tvm
......@@ -82,18 +82,11 @@ RELAY_REGISTER_UNARY_OP("sigmoid")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
// Concat
TVM_REGISTER_API("relay.op._make.concat")
.set_body_typed<Expr(Expr)>([](Expr tuple) {
static const Op& op = Op::Get("concat");
return CallNode::make(op, { tuple }, Attrs(), {});
});
RELAY_REGISTER_OP("concat")
.set_num_inputs(1)
.add_argument("tuple", "Tuple", "The tupled tensor arguments.")
.set_support_level(1)
.add_type_rel("Concat", ConcatRel);
RELAY_REGISTER_UNARY_OP("copy")
.describe(R"code(Copy a tensor.
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
} // namespace relay
} // namespace tvm
......@@ -44,7 +44,46 @@ def test_unary_op():
assert ftype.ret_type == relay.TensorType((10, 4), "int32")
def test_concatenate_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
y = ib.param("y", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x, y) as func:
ib.ret(relay.concatenate((x, y), axis=-1))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, t, 200), "float32")
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
y = ib.param("y", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x, y) as func:
ib.ret(relay.concatenate((x, y), axis=2))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, t, 200), "float32")
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
y = ib.param("y", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x, y) as func:
ib.ret(relay.concatenate((x, y), axis=1))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, t + t, 100), "float32")
if __name__ == "__main__":
test_expand_dims_infer_type()
test_unary_op()
test_concatenate_infer_type()
test_softmax()
import tvm
from tvm import relay
def test_unary_identity():
for op in [relay.zeros_like, relay.ones_like]:
ib = relay.ir_builder.IRBuilder()
......@@ -11,3 +12,49 @@ def test_unary_identity():
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32")
def test_copy_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.copy(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, t, 100), "float32")
def test_transpose_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.transpose(x, axes=(1, 0, 2)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(t, n, 100), "float32")
def test_reshape_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20
x = ib.param("x", relay.ty.TensorType((n, t, d1, d2), "float32"))
with ib.function(x) as func:
ib.ret(relay.reshape(x, newshape=(n, t, 2000)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, t, 2000), "float32")
if __name__ == "__main__":
test_unary_identity()
test_copy_infer_type()
test_transpose_infer_type()
test_reshape_infer_type()
......@@ -2,7 +2,7 @@ import tvm
from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from tvm.relay.ir_builder import convert, IRBuilder
from tvm.relay.op import log, add, equal, subtract, concat
from tvm.relay.op import log, add, equal, subtract
class env:
def __init__(self):
......
......@@ -7,7 +7,7 @@ from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
from tvm.relay.op import log, add, equal, subtract, concat
from tvm.relay.op import log, add, equal, subtract, concatenate
from tvm.relay.expr import Function
def assert_has_type(expr, typ, env=Environment({})):
......@@ -146,7 +146,7 @@ def test_concat():
"""
Program:
def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
return concat(x, y);
return concatenate((x, y), axis=0);
}
"""
ib = IRBuilder()
......@@ -154,7 +154,7 @@ def test_concat():
x = ib.param('x', ty=tensor_type(3, 2))
y = ib.param('y', ty=tensor_type(2, 2))
with ib.decl(try_concat2, x, y):
ib.ret(concat(x, y))
ib.ret(concatenate((x, y), axis=0))
fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
assert_decl_has_type(ib.env, try_concat2, fn_ty)
......
......@@ -38,10 +38,6 @@ inline Tensor expand_dims(const Tensor& x,
std::string name = "tensor",
std::string tag = kBroadcast) {
int ndim = static_cast<int>(x->shape.size());
if (axis < 0) {
// Calculate offset from last dimension
axis = ndim + axis + 1;
}
CHECK(-ndim - 1 <= axis && axis <= ndim)
<< "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
<< ", but got axis = " << axis
......@@ -49,7 +45,10 @@ inline Tensor expand_dims(const Tensor& x,
CHECK(num_newaxis >= 0)
<< "expand_dims only accepts `num_newaxis >= 0`"
<< ", but got num_newaxis = " << num_newaxis;
if (axis < 0) {
// Calculate offset from last dimension
axis = ndim + axis + 1;
}
Array<Expr> new_shape;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
new_shape.push_back(x->shape[i]);
......@@ -265,8 +264,13 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
int axis = 0,
std::string name = "tensor",
std::string tag = kInjective) {
int ndim = static_cast<int>(inputs[0]->shape.size());
CHECK(-ndim <= axis && axis < ndim)
<< "concatenate only accepts `axis` in [-ndim, ndim)"
<< ", but got axis = " << axis
<< ", and ndim = " << ndim;
if (axis < 0) {
axis += static_cast<int>(inputs[0]->shape.size());
axis += ndim;
}
CHECK_LT(axis, inputs[0]->shape.size()) <<
"axis out of bounds";
......
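Both hunks above apply the same fix: validate the raw, possibly negative axis before normalizing it; normalizing first lets some out-of-range negative values slip through the check. A hedged Python sketch of the pattern (`_normalize_axis` is an illustrative helper; `extra=1` models expand_dims, whose valid range is one wider on each side):

def _normalize_axis(axis, ndim, extra=0):
    assert -ndim - extra <= axis < ndim + extra, "axis out of bounds"
    return axis + ndim + extra if axis < 0 else axis  # normalize afterwards

assert _normalize_axis(-1, 3) == 2           # concatenate: axis in [-ndim, ndim)
assert _normalize_axis(-1, 3, extra=1) == 3  # expand_dims: axis in [-ndim-1, ndim]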