Fixing medusa without prefix caching.

This commit is contained in:
Nicolas Patry 2024-08-13 13:13:08 +02:00
parent 4c8dcbb76d
commit b2933b72d0
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@@ -1496,9 +1496,9 @@ class FlashCausalLM(Model):
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["input_lengths"][
+            : input_lengths.shape[0]
+        ] = input_lengths  # + prefix_lens_tensor
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -1920,7 +1920,7 @@ class FlashCausalLM(Model):
                 prefix_lens=prefix_lens,
             ),
             cu_seqlens=cu_seqlen_prefill,
-            input_lengths=input_lengths_tensor + prefix_lens_tensor,
+            input_lengths=input_lengths_tensor,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             head_size=self.head_size,