diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 387118c2..830dc6c2 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1507,7 +1507,7 @@ class FlashCausalLM(Model):
             input_lengths_tensor=cuda_graph["input_lengths"],
             prefix_lens=batch.prefix_lens,
             prefix_lens_tensor=prefix_lens_tensor,
-            state=cuda_graph["state"],
+            state=cuda_graph.get("state"),
         ):
             # Replay the graph
             cuda_graph["graph"].replay()
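
The only functional change in this hunk is the lookup of the optional `state` entry: `cuda_graph["state"]` raises a `KeyError` when no state was stored for the captured graph, whereas `cuda_graph.get("state")` returns `None` and lets the surrounding context manager fall back to its own default. A minimal sketch of that difference follows; the `cuda_graph` dict below is a hypothetical placeholder, not the real captured-graph structure from TGI.

```python
# Hypothetical stand-in for a captured CUDA graph entry; the real dict in
# flash_causal_lm.py holds tensors and a torch.cuda.CUDAGraph object.
cuda_graph = {
    "graph": object(),            # placeholder for the captured graph
    "input_lengths": [1, 2, 3],   # placeholder per-sequence lengths
    # note: no "state" key was captured for this graph
}

# Old behavior: direct indexing fails if "state" is missing.
try:
    state = cuda_graph["state"]
except KeyError:
    state = None  # the caller would have crashed before reaching replay

# New behavior: dict.get returns None directly, so the keyword argument is
# simply passed as None and no exception interrupts graph replay.
state = cuda_graph.get("state")
print(state)  # -> None
```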