From b2933b72d0895cd2ee1d4fb61a53488dcf852595 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 13 Aug 2024 13:13:08 +0200
Subject: [PATCH] Fixing medusa without prefix caching.

---
 server/text_generation_server/models/flash_causal_lm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 830dc6c2..82669215 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1496,9 +1496,9 @@ class FlashCausalLM(Model):
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["input_lengths"][
+            : input_lengths.shape[0]
+        ] = input_lengths  # + prefix_lens_tensor
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -1920,7 +1920,7 @@ class FlashCausalLM(Model):
                 prefix_lens=prefix_lens,
             ),
             cu_seqlens=cu_seqlen_prefill,
-            input_lengths=input_lengths_tensor + prefix_lens_tensor,
+            input_lengths=input_lengths_tensor,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             head_size=self.head_size,
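
Illustrative note (not part of the patch): a minimal runnable sketch of what the changed decode-path assignment does after this commit. The names cuda_graph, input_lengths, and prefix_lens_tensor mirror the diff; the tensor sizes and values are made up, and the assumption that prefix lengths are all zero when prefix caching is disabled is inferred from the subject line, not stated in the patch itself.

    # Sketch only: fill the padded CUDA-graph buffer from the raw input
    # lengths, without adding prefix lengths (the term commented out above).
    import torch

    cuda_graph = {"input_lengths": torch.zeros(8, dtype=torch.int32)}  # padded graph buffer
    input_lengths = torch.tensor([5, 7, 3], dtype=torch.int32)         # current batch lengths
    prefix_lens_tensor = torch.zeros_like(input_lengths)               # assumed zero: no prefix caching

    cuda_graph["input_lengths"].zero_()
    # Patched behaviour: copy input_lengths alone into the graph buffer.
    cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

    print(cuda_graph["input_lengths"])  # tensor([5, 7, 3, 0, 0, 0, 0, 0], dtype=torch.int32)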