fix: prefer inplace softmax to avoid copy

David Holtz 2024-10-17 02:53:32 +00:00
parent a6a0c97ed9
commit 8d7448de9f


@@ -1922,8 +1922,9 @@ class FlashCausalLM(Model):
         batch.adapter_meta.adapter_indices = next_adapter_indices
         if prefill and prefill_logprobs:
-            # Get prefill logprobs
-            prefill_logprobs_tensor = torch.log_softmax(out, -1)
+            # Get prefill logprobs with an inplace softmax (avoids copying the `out` tensor of size max_batch_size * vocab_size)
+            torch.log_softmax(out, -1, out=out)
+            prefill_logprobs_tensor = out
             prefill_logprobs = torch.gather(
                 prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
             )
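
For context, here is a minimal standalone sketch of the before/after behavior. This is not TGI code: the tensor shapes and the `token_indices` input are illustrative stand-ins, and it assumes only that PyTorch is installed.

    import torch

    num_tokens, vocab_size = 4096, 32000
    out = torch.randn(num_tokens, vocab_size)  # stand-in for the model logits

    # Before: torch.log_softmax allocates a second (num_tokens, vocab_size)
    # tensor for its result, doubling the memory held at this point.
    #   prefill_logprobs_tensor = torch.log_softmax(out, -1)

    # After: the `out=` argument writes the log-probabilities back into the
    # logits tensor, so no extra copy is materialized. This is safe only if
    # the raw logits are not read again afterward, which appears to be the
    # case on this code path (an assumption about the surrounding TGI code).
    torch.log_softmax(out, -1, out=out)
    prefill_logprobs_tensor = out

    # Gather per-token logprobs, mirroring the surrounding code.
    token_indices = torch.randint(0, vocab_size, (num_tokens,))
    prefill_logprobs = torch.gather(
        prefill_logprobs_tensor, 1, token_indices.view(-1, 1)
    )
    print(prefill_logprobs.shape)  # torch.Size([4096, 1])

The trade-off is that the in-place call destroys the original logits; since this branch only needs the log-probabilities, reusing the buffer saves one max_batch_size * vocab_size allocation at the memory-heavy prefill step.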