Just medusa values now.

This commit is contained in:
Nicolas Patry 2024-08-13 13:02:48 +02:00
parent 549f0e9ca7
commit 4c8dcbb76d
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@@ -1507,7 +1507,7 @@ class FlashCausalLM(Model):
input_lengths_tensor=cuda_graph["input_lengths"],
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
-state=cuda_graph["state"],
+state=cuda_graph.get("state"),
):
# Replay the graph
cuda_graph["graph"].replay()