Commit 7e34988e by Lianmin Zheng Committed by Tianqi Chen

[TOPI] Rename output tensors for better readability (#3006)

parent c64a33ed
......@@ -46,7 +46,7 @@ namespace topi {
*/
inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
const tvm::Array<tvm::Expr>& output_shape,
std::string name = "tensor",
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
CHECK_GE(output_shape.size(), t->shape.size())
<< "Not a broadcast, output dimensionality smaller than input.\noutput: "
......@@ -66,35 +66,35 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
tag);
}
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
inline tvm::Expr Name(const tvm::Expr& a, \
const tvm::Expr& b) { \
ComputeRule; \
} \
inline tvm::Tensor Name(const tvm::Tensor& A, \
const tvm::Tensor& B, \
std::string name = "tensor", \
std::string tag = kBroadcast) { \
auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
return detail::WithBroadcast(l, A, B, name, tag); \
} \
inline tvm::Tensor Name(const tvm::Tensor& A, \
const tvm::Expr& B, \
std::string name = "tensor", \
std::string tag = kElementWise) { \
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
inline tvm::Expr Name(const tvm::Expr& a, \
const tvm::Expr& b) { \
ComputeRule; \
} \
inline tvm::Tensor Name(const tvm::Tensor& A, \
const tvm::Tensor& B, \
std::string name = "T_" #Name, \
std::string tag = kBroadcast) { \
auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
return detail::WithBroadcast(l, A, B, name, tag); \
} \
inline tvm::Tensor Name(const tvm::Tensor& A, \
const tvm::Expr& B, \
std::string name = "T_" #Name, \
std::string tag = kElementWise) { \
auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
return compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
return l(A(i), B); \
}, name, tag); \
} \
inline tvm::Tensor Name(const tvm::Expr& A, \
const tvm::Tensor& B, \
std::string name = "tensor", \
std::string tag = kElementWise) { \
auto l = [&](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
return l(A(i), B); \
}, name, tag); \
} \
inline tvm::Tensor Name(const tvm::Expr& A, \
const tvm::Tensor& B, \
std::string name = "T_" #Name, \
std::string tag = kElementWise) { \
auto l = [&](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
return compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
return l(A, B(i)); \
}, name, tag); \
return l(A, B(i)); \
}, name, tag); \
}
......
......@@ -38,7 +38,7 @@ using namespace tvm;
// Unary intrinsic operators
#define TOPI_DECLARE_UNARY_OP(OpName) \
inline Tensor OpName(const Tensor& x, \
std::string name = "tensor", \
std::string name = "T_" #OpName, \
std::string tag = kElementWise) { \
return compute(x->shape, [&](const Array<Var>& i) { \
return ::tvm::OpName(x(i)); \
......@@ -66,7 +66,7 @@ TOPI_DECLARE_UNARY_OP(abs);
* \return A Tensor whose op member is the identity operation
*/
inline Tensor identity(const Tensor& x,
std::string name = "tensor",
std::string name = "T_identity",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return x(i);
......@@ -83,7 +83,7 @@ inline Tensor identity(const Tensor& x,
* \return A Tensor whose op member is the negation operation
*/
inline Tensor negative(const Tensor& x,
std::string name = "tensor",
std::string name = "T_negative",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return -x(i);
......@@ -100,7 +100,7 @@ inline Tensor negative(const Tensor& x,
* \return A Tensor whose op member is the logical NOT operation
*/
inline Tensor logical_not(const Tensor& x,
std::string name = "tensor",
std::string name = "T_logical_not",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return !x(i);
......@@ -117,7 +117,7 @@ inline Tensor logical_not(const Tensor& x,
* \return A Tensor whose op member is the sign
*/
inline Tensor sign(const Tensor& x,
std::string name = "tensor",
std::string name = "T_sign",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr zero = make_zero(x->dtype);
......@@ -144,7 +144,7 @@ inline Tensor sign(const Tensor& x,
inline Tensor clip(const Tensor& x,
const Expr& a_min,
const Expr& a_max,
std::string name = "tensor",
std::string name = "T_clip",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
auto min_val = tvm::cast(x->dtype, a_min);
......@@ -167,7 +167,7 @@ inline Tensor clip(const Tensor& x,
*/
inline Tensor cast(const Tensor& x,
Type type,
std::string name = "tensor",
std::string name = "T_cast",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
auto expr = x(i);
......@@ -193,7 +193,7 @@ inline Tensor cast(const Tensor& x,
* \return A Tensor whose op member is the sum operation
*/
inline Tensor elemwise_sum(const Array<Tensor>& xs,
std::string name = "tensor",
std::string name = "T_elemwise_sum",
std::string tag = kElementWise) {
CHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor.";
return compute(xs[0]->shape, [&](const Array<Var>& i) {
......@@ -219,7 +219,7 @@ inline Tensor elemwise_sum(const Array<Tensor>& xs,
inline Tensor full(const Array<Expr>& shape,
Type dtype,
const Expr fill_value,
std::string name = "tensor",
std::string name = "T_full",
std::string tag = kElementWise) {
Expr ev = cast(dtype, fill_value);
if (!ev.defined()) {
......@@ -243,7 +243,7 @@ inline Tensor full(const Array<Expr>& shape,
*/
inline Tensor full_like(const Tensor& x,
const Expr fill_value,
std::string name = "tensor",
std::string name = "T_full_like",
std::string tag = kElementWise) {
Expr ev = cast(x->dtype, fill_value);
return compute(x->shape, [&](const Array<Var>& i) {
......
......@@ -63,7 +63,7 @@ tvm::Expr Map(const tvm::Array<tvm::Expr>& exprs, T op) {
template <typename T>
inline tvm::Tensor relu(const tvm::Tensor& t,
T threshold = static_cast<T>(0),
std::string name = "tensor",
std::string name = "T_relu",
std::string tag = kElementWise) {
return tvm::compute(
t->shape,
......@@ -87,7 +87,7 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
*/
inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
double alpha = 0.1,
std::string name = "tensor",
std::string name = "T_leaky_relu",
std::string tag = kElementWise) {
return tvm::compute(
t->shape,
......@@ -114,7 +114,7 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
inline tvm::Tensor prelu(const tvm::Tensor &x,
const tvm::Tensor &slope,
const int axis = 1,
std::string name = "tensor",
std::string name = "T_prelu",
std::string tag = kBroadcast) {
CHECK((size_t)axis < x->shape.size()) <<
"Wrong axis (" << axis << ")value. ";
......@@ -171,7 +171,7 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
const tvm::Array<tvm::Expr>& pad_before,
tvm::Array<tvm::Expr> pad_after = tvm::Array<tvm::Expr>(),
Expr pad_value = Expr(),
std::string name = "tensor",
std::string name = "T_pad",
std::string tag = kElementWise) {
if (pad_after.size() < pad_before.size()) {
for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
......@@ -247,7 +247,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string name = "T_conv2d_nchw",
std::string tag = kConv2dNCHW) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
......@@ -298,7 +298,7 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string name = "T_conv2d_hwcn",
std::string tag = kConv2dHWCN) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
......@@ -349,7 +349,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string name = "T_depthwise_conv2d_nchw",
std::string tag = kDepthwiseConv2dNCHW) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
......@@ -382,7 +382,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string name = "T_depthwise_conv2d_nhwc",
std::string tag = kDepthwiseConv2dNHWC) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
......@@ -435,7 +435,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string name = "T_group_conv2d_ngchw",
std::string tag = kGroupConv2d) {
CHECK_EQ(5, I->shape.size());
CHECK_EQ(5, W->shape.size());
......
......@@ -272,8 +272,8 @@ inline Tensor global_pool(const Tensor& x,
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
auto dheight = tvm::reduce_axis(Range(0, height));
auto dwidth = tvm::reduce_axis(Range(0, width));
auto dheight = tvm::reduce_axis(Range(0, height), "rv1");
auto dwidth = tvm::reduce_axis(Range(0, width), "rv2");
if (pool_type == kMaxPool) {
return tvm::compute(out_shape,
......
......@@ -57,7 +57,7 @@ using namespace topi::detail;
inline Tensor expand_dims(const Tensor& x,
int axis,
int num_newaxis = 1,
std::string name = "tensor",
std::string name = "T_expand_dims",
std::string tag = kBroadcast) {
int ndim = static_cast<int>(x->shape.size());
CHECK(-ndim - 1 <= axis && axis <= ndim)
......@@ -108,7 +108,7 @@ inline Tensor expand_dims(const Tensor& x,
*/
inline Tensor transpose(const Tensor& x,
Array<Integer> axes,
std::string name = "tensor",
std::string name = "T_transpose",
std::string tag = kInjective) {
if (!axes.defined() || axes.size() == 0) {
axes = Array<Integer>();
......@@ -164,7 +164,7 @@ inline Tensor transpose(const Tensor& x,
*/
inline Tensor flip(const Tensor& x,
int axis = 0,
std::string name = "tensor",
std::string name = "T_flip",
std::string tag = kInjective) {
size_t src_tensor_dim = x->shape.size();
int axis_inp = axis;
......@@ -204,7 +204,7 @@ inline Tensor flip(const Tensor& x,
*/
inline Tensor reshape(const Tensor& x,
Array<Expr> newshape,
std::string name = "tensor",
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
return compute(
......@@ -229,7 +229,7 @@ inline Tensor reshape(const Tensor& x,
inline Tensor squeeze(const Tensor& x,
Array<Integer> axis,
bool atleast1d = false,
std::string name = "tensor",
std::string name = "T_squeeze",
std::string tag = kInjective) {
auto ndim = x->shape.size();
std::vector<int> axis_val;
......@@ -291,7 +291,7 @@ inline Tensor squeeze(const Tensor& x,
*/
inline Tensor concatenate(const Array<Tensor>& inputs,
int axis = 0,
std::string name = "tensor",
std::string name = "T_concat",
std::string tag = kInjective) {
int ndim = static_cast<int>(inputs[0]->shape.size());
CHECK(-ndim <= axis && axis < ndim)
......@@ -355,7 +355,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
*/
inline Tensor stack(const Array<Tensor>& inputs,
int axis = 0,
std::string name = "tensor",
std::string name = "T_stack",
std::string tag = kInjective) {
int ndim = static_cast<int>(inputs[0]->shape.size());
CHECK(-ndim - 1 <= axis && axis <= ndim)
......@@ -408,7 +408,7 @@ inline Tensor stack(const Array<Tensor>& inputs,
inline Array<Tensor> split(const Tensor& x,
Array<Integer> split_indices,
int axis,
std::string name = "tensor",
std::string name = "T_split",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
......@@ -486,7 +486,7 @@ inline Tensor strided_slice(const Tensor& x,
const Array<Integer>& begin,
const Array<Integer>& end,
const Array<Integer>& strides,
std::string name = "tensor",
std::string name = "T_strided_slice",
std::string tag = kInjective) {
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
// Setup the ranges.
......@@ -585,7 +585,7 @@ inline Tensor strided_slice(const Tensor& x,
inline Array<Tensor> split_sections(const Tensor& x,
int num_sections,
int axis,
std::string name = "tensor",
std::string name = "T_split_sections",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
......@@ -624,7 +624,7 @@ inline Array<Tensor> split_sections(const Tensor& x,
inline Tensor take(const Tensor& a,
const Tensor& indices,
std::string mode = "clip",
std::string name = "tensor",
std::string name = "T_take",
std::string tag = kInjective) {
Array<Expr> a_shape = a->shape;
Array<Expr> out_shape = indices->shape;
......@@ -664,7 +664,7 @@ inline Tensor take(const Tensor& a,
const Tensor& indices,
int axis,
std::string mode = "clip",
std::string name = "tensor",
std::string name = "T_take",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(a->shape.size());
......@@ -738,7 +738,7 @@ inline Tensor take(const Tensor& a,
inline Tensor where(const Tensor& condition,
const Tensor& x,
const Tensor& y,
std::string name = "tensor",
std::string name = "T_where",
std::string tag = kInjective) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: "
......@@ -786,7 +786,7 @@ inline Tensor where(const Tensor& condition,
inline Tensor repeat(const Tensor& x,
int repeats,
int axis,
std::string name = "tensor",
std::string name = "T_repeat",
std::string tag = kBroadcast) {
int ndim = static_cast<int>(x->shape.size());
CHECK(-ndim - 1 <= axis && axis <= ndim)
......@@ -835,7 +835,7 @@ inline Tensor repeat(const Tensor& x,
*/
inline Tensor tile(const Tensor& x,
Array<Integer> reps,
std::string name = "tensor",
std::string name = "T_tile",
std::string tag = kBroadcast) {
size_t ndim = x->shape.size();
size_t rdim = reps.size();
......@@ -892,7 +892,7 @@ inline Tensor tile(const Tensor& x,
*/
inline Tensor gather_nd(const Tensor& data,
const Tensor& indices,
std::string name = "tensor",
std::string name = "T_gather_nd",
std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
......@@ -953,7 +953,7 @@ inline tvm::Tensor matmul(const tvm::Tensor& A,
const tvm::Tensor& B,
bool trans_a = false,
bool trans_b = false,
std::string name = "tensor",
std::string name = "T_matmul",
std::string tag = kMatMul) {
tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
B->shape[trans_b ? 0 : 1]};
......@@ -979,7 +979,7 @@ inline tvm::Tensor matmul(const tvm::Tensor& A,
inline Tensor tensordot(const Tensor& A,
const tvm::Tensor& B,
int axes = 2,
std::string name = "tensor",
std::string name = "T_tensordot",
std::string tag = kMatMul) {
CHECK_GE(A->shape.size(), axes);
CHECK_GE(B->shape.size(), axes);
......@@ -1035,7 +1035,7 @@ inline Tensor tensordot(const Tensor& A,
const tvm::Tensor& B,
Array<Expr> A_axes,
Array<Expr> B_axes,
std::string name = "tensor",
std::string name = "T_tensordot",
std::string tag = kMatMul) {
CHECK_EQ(A_axes.size(), B_axes.size());
......@@ -1084,7 +1084,7 @@ inline Tensor arange(const Expr start,
const Expr stop,
const Expr step,
Type dtype,
std::string name = "tensor",
std::string name = "T_arange",
std::string tag = kInjective) {
Expr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
tvm::cast(tvm::Float(32), stop - start) / step));
......@@ -1106,7 +1106,7 @@ inline Tensor arange(const Expr start,
inline Tensor layout_transform(const Tensor& src,
const std::string& src_layout,
const std::string& dst_layout,
const std::string name = "layout_transform",
const std::string name = "T_layout_trans",
const std::string tag = kInjective) {
Layout src_layout_struct = LayoutNode::make(src_layout);
Layout dst_layout_struct = LayoutNode::make(dst_layout);
......@@ -1142,7 +1142,7 @@ inline Tensor layout_transform(const Tensor& src,
*/
inline Tensor shape(const Tensor& src,
Type dtype,
const std::string name = "shape",
const std::string name = "T_shape",
const std::string tag = kInjective) {
int ndim = static_cast<int>(src->shape.size());
Array<Expr> out_shape{ndim};
......
......@@ -47,7 +47,7 @@ def dense_default(data, weight, bias=None):
k = tvm.reduce_axis((0, in_dim), name='k')
matmul = tvm.compute((batch, out_dim), \
lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
tag='dense')
name='T_dense', tag='dense')
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
......
......@@ -61,9 +61,11 @@ def softmax(x, axis=-1):
return tvm.exp(x[indices] - max_elem[non_reduce_indices]) / expsum[non_reduce_indices]
reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
max_elem = tvm.compute(reduced_shape, _compute_max)
expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices))
return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices))
max_elem = tvm.compute(reduced_shape, _compute_max, name='T_softmax_maxelem')
expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices),
name='T_softmax_expsum')
return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices),
name='T_softmax_norm')
@tvm.tag_scope(tag='log_softmax_output')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment