Commit 63a3477a by Pariksheet Pinjari Committed by Tianqi Chen

Update softmax.h (#1057)

parent e15aae2b
...@@ -83,7 +83,7 @@ inline Tensor softmax(const Tensor &x, ...@@ -83,7 +83,7 @@ inline Tensor softmax(const Tensor &x,
}); });
return tvm::compute(input_shape, [&](const Array<Var> &indices) { return tvm::compute(input_shape, [&](const Array<Var> &indices) {
return _normalize(max_elem, expsum, indices); return _normalize(max_elem, expsum, indices);
}); }, name, tag);
} }
/*! /*!
...@@ -116,7 +116,7 @@ inline Tensor log_softmax(const Tensor& x, ...@@ -116,7 +116,7 @@ inline Tensor log_softmax(const Tensor& x,
return tvm::compute( return tvm::compute(
x->shape, [&](Var i, Var j) { x->shape, [&](Var i, Var j) {
return x(i, j) - max_elem(i) - tvm::log(expsum(i)); return x(i, j) - max_elem(i) - tvm::log(expsum(i));
}); }, name, tag);
} }
} // namespace nn } // namespace nn
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment