Commit 75654835 by nhynes Committed by Yizhi Liu

Allow log_softmax on explicit trailing dim (#1684)

parent d173e637
......@@ -410,7 +410,8 @@ NNVM_REGISTER_OP(log_softmax)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
CHECK(param.axis == -1 || param.axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
<< "log_softmax currently only works on last dimension";
return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
})
.set_attr<FGradient>(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment