Just medusa values now.

This commit is contained in:
Nicolas Patry 2024-08-13 13:02:48 +02:00
parent 549f0e9ca7
commit 4c8dcbb76d
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@@ -1507,7 +1507,7 @@ class FlashCausalLM(Model):
 input_lengths_tensor=cuda_graph["input_lengths"],
 prefix_lens=batch.prefix_lens,
 prefix_lens_tensor=prefix_lens_tensor,
-state=cuda_graph["state"],
+state=cuda_graph.get("state"),
 ):
 # Replay the graph
 cuda_graph["graph"].replay()