Update server/text_generation_server/models/flash_causal_lm.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
drbh 2024-10-17 08:48:52 -04:00 committed by GitHub
parent 8d7448de9f
commit 3e0a82d512
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1922,7 +1922,7 @@ class FlashCausalLM(Model):
batch.adapter_meta.adapter_indices = next_adapter_indices batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill and prefill_logprobs: if prefill and prefill_logprobs:
# Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_size * vocab_size)) # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
torch.log_softmax(out, -1, out=out) torch.log_softmax(out, -1, out=out)
prefill_logprobs_tensor = out prefill_logprobs_tensor = out
prefill_logprobs = torch.gather( prefill_logprobs = torch.gather(