From f9520245336d4631a80e98dac011e3a9e2d899c9 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 6 Sep 2024 16:59:11 +0200
Subject: [PATCH] Remove other tensor creation.

---
 .../text_generation_server/models/flash_causal_lm.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index dc509a55..fe77257b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1215,13 +1215,6 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            seqlen = Seqlen(
-                input_lengths=input_lengths_tensor,
-                prefix_lengths=prefix_lengths_tensor,
-                cu_seqlen_q=None,
-                max_q=1,
-                max_k=max_s,
-            )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1527,6 +1520,8 @@ class FlashCausalLM(Model):
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
             input_lengths + prefix_lens_tensor
         )
+        cuda_graph["prefix_lengths"].zero_()
+        cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -1534,7 +1529,7 @@
             input_lengths=batch.input_lengths,
             input_lengths_tensor=cuda_graph["input_lengths"],
             prefix_lens=batch.prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lens_tensor=cuda_graph["prefix_lengths"],
             state=cuda_graph.get("state"),
         ):
             # Replay the graph
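
Note (editor's sketch, not part of the commit): torch.cuda.graph replay re-executes
the captured kernels against the exact memory addresses recorded at capture time.
That is why the hunks above stop constructing a fresh Seqlen/tensor at replay and
instead copy prefix_lens_tensor into the pre-allocated cuda_graph["prefix_lengths"]
buffer in place. A minimal standalone illustration of that static-buffer pattern
follows; MAX_BS, static_input, static_output, and replay are illustrative names,
not identifiers from the TGI codebase.

import torch

# Static buffers sized for the largest batch the graph will ever serve.
MAX_BS = 8
static_input = torch.zeros(MAX_BS, dtype=torch.int64, device="cuda")
static_output = torch.zeros(MAX_BS, dtype=torch.int64, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output.copy_(static_input * 2)
torch.cuda.current_stream().wait_stream(s)

# Capture: every tensor touched here is baked into the graph by address,
# so replays always read from and write to these exact buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output.copy_(static_input * 2)

def replay(values: torch.Tensor) -> torch.Tensor:
    # Refresh the captured input in place (zero, then slice-assign the live
    # values), mirroring the zero_() + slice assignment in the patch. Binding
    # a newly allocated tensor here instead would leave the graph reading the
    # old, stale buffer.
    static_input.zero_()
    static_input[: values.shape[0]] = values
    graph.replay()
    return static_output[: values.shape[0]].clone()

print(replay(torch.tensor([1, 2, 3], device="cuda")))  # tensor([2, 4, 6], device='cuda:0')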