Commit cb2a599d by Lianmin Zheng Committed by Tianqi Chen

[RELAY] Add softmax (#1841)

parent b313c021
......@@ -28,6 +28,7 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.sigmoid
tvm.relay.add
tvm.relay.expand_dims
tvm.relay.nn.softmax
**Level 2: Convolutions**
......
......@@ -67,6 +67,16 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
}
};
/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;
TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
TVM_ATTR_FIELD(axis).set_default(1)
.describe("The axis to sum over when computing softmax.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
......@@ -86,3 +86,23 @@ def conv2d(data,
return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
weight_layout, out_layout, out_dtype)
def softmax(data, axis):
r"""Computes softmax.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
.. note::
This operator can be optimized away for inference.
Parameters
----------
data: relay.Expr
The input data to the operator.
axis: int
The axis to sum over when computing softmax
"""
return _make.softmax(data, axis)
/*!
* Copyright (c) 2018 by Contributors
* \file nn.cc
* \brief Property def of nn operators.
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include "../type_relations.h"
namespace tvm {
namespace relay {
TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) {
auto attrs = make_node<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.softmax");
return CallNode::make(op, {data}, Attrs(attrs), {});
};
runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
});
RELAY_REGISTER_OP("nn.softmax")
.describe(R"code(Softmax layer.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
.. note::
This operator can be optimized away for inference.
- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
} // namespace relay
} // namespace tvm
......@@ -16,6 +16,19 @@ def test_expand_dims_infer_type():
(n, t, 1, 100), "float32")
def test_softmax():
ib = relay.ir_builder.IRBuilder()
n, d = tvm.var("n"), tvm.var("d")
x = ib.param("x", relay.ty.TensorType((n, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.nn.softmax(x, axis=1))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.ty.TensorType((n, d), "float32")
def test_unary_op():
for op in [relay.exp,
relay.log,
......@@ -34,3 +47,4 @@ def test_unary_op():
if __name__ == "__main__":
test_expand_dims_infer_type()
test_unary_op()
test_softmax()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment