The existing `logprobs_from_logits_v2` doesn't achieve the memory savings it claims, because `logsumexp` still allocates a `bs * seqlen * vocab` tensor internally to hold the element-wise `exp`. By looping over the batch and applying `logsumexp` one row at a time, the logsumexp outputs can be computed iteratively, and benchmarks show this uses significantly less memory to compute the log-probabilities. A fix is provided, along with a separate memory-efficient approach for the bfloat16 case.
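
A minimal sketch of the looped-`logsumexp` idea, assuming `logits` of shape `(bs, seqlen, vocab)` and integer `labels` of shape `(bs, seqlen)`. The function name mirrors the one discussed above, but the loop granularity (per batch row) and the bfloat16 fallback shown here are illustrative assumptions, not necessarily the exact merged fix:

```python
import torch
import torch.nn.functional as F


def logprobs_from_logits_v2(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Memory-efficient log-prob gathering: loop over the batch so logsumexp
    never materializes a full (bs, seqlen, vocab) intermediate at once."""
    if logits.dtype in (torch.float32, torch.float64):
        # log_softmax(x)[label] = x[label] - logsumexp(x)
        logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        # Iterate per batch row: each logsumexp call only allocates (seqlen, vocab) internally.
        logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
        logprobs_labels = logits_labels - logsumexp_values
    else:
        # Assumed bfloat16 path: the logsumexp subtraction loses precision in bf16,
        # so fall back to a per-row log_softmax + gather, still avoiding a full-size buffer.
        rows = []
        for row_logits, row_labels in zip(logits, labels):
            row_logprobs = F.log_softmax(row_logits, dim=-1)
            rows.append(row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1))
        logprobs_labels = torch.stack(rows)
    return logprobs_labels
```

For example, with `logits = torch.randn(4, 512, 32000)` and `labels = torch.randint(0, 32000, (4, 512))`, the peak extra allocation inside the loop is roughly `seqlen * vocab` floats per iteration rather than `bs * seqlen * vocab` all at once.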