From 8d7448de9fd1028514d9666fd0ead6f1937456ba Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Thu, 17 Oct 2024 02:53:32 +0000
Subject: [PATCH] fix: prefer inplace softmax to avoid copy

---
 server/text_generation_server/models/flash_causal_lm.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c9b7decd..3b16a724 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1922,8 +1922,9 @@ class FlashCausalLM(Model):
             batch.adapter_meta.adapter_indices = next_adapter_indices

         if prefill and prefill_logprobs:
-            # Get prefill logprobs
-            prefill_logprobs_tensor = torch.log_softmax(out, -1)
+            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_size * vocab_size))
+            torch.log_softmax(out, -1, out=out)
+            prefill_logprobs_tensor = out
             prefill_logprobs = torch.gather(
                 prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
             )
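
Note: the sketch below illustrates the in-place log_softmax pattern this patch applies, outside of the TGI codebase. The tensor names, shapes (batch_size, vocab_size), and the token_indices tensor are illustrative only, not the actual values from flash_causal_lm.py, and it assumes a PyTorch version whose torch.log_softmax accepts the out= keyword.

import torch

# Illustrative shapes; the real `out` tensor in flash_causal_lm.py is
# (batch_size, vocab_size) logits for the prefill tokens.
batch_size, vocab_size = 4, 32000
logits = torch.randn(batch_size, vocab_size)

# Out-of-place: allocates a second (batch_size, vocab_size) tensor for the
# result, doubling peak memory for this step.
logprobs_copy = torch.log_softmax(logits, -1)

# In-place: write the result back into `logits` via the `out=` argument,
# avoiding the extra allocation. Assumes torch.log_softmax exposes the
# `out` keyword in the installed PyTorch version.
torch.log_softmax(logits, -1, out=logits)
assert torch.allclose(logits, logprobs_copy)

# Gather one log-probability per row, mirroring the torch.gather call in the
# patch; token_indices stands in for prefill_tokens_indices.
token_indices = torch.randint(0, vocab_size, (batch_size, 1))
prefill_logprobs = torch.gather(logits, 1, token_indices)
print(prefill_logprobs.shape)  # torch.Size([4, 1])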