Mirror of https://github.com/huggingface/text-generation-inference.git
Commit b2933b72d0: Fixing medusa without prefix caching.
Parent: 4c8dcbb76d
@@ -1496,9 +1496,9 @@ class FlashCausalLM(Model):
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["input_lengths"][
+            : input_lengths.shape[0]
+        ] = input_lengths  # + prefix_lens_tensor
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -1920,7 +1920,7 @@ class FlashCausalLM(Model):
                     prefix_lens=prefix_lens,
                 ),
                 cu_seqlens=cu_seqlen_prefill,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=input_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
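A minimal sketch of what both hunks amount to, assuming a hypothetical prefix_caching_enabled flag (not a TGI identifier): when prefix caching is not in use, as on the Medusa path this commit fixes, there are no cached prefix tokens to account for, so the lengths handed to the attention context are just the per-request input lengths rather than input_lengths + prefix_lens_tensor.

import torch

def effective_input_lengths(
    input_lengths: torch.Tensor,
    prefix_lens: torch.Tensor,
    prefix_caching_enabled: bool,  # hypothetical flag, for illustration only
) -> torch.Tensor:
    # With prefix caching, tokens already held in the KV cache are counted
    # in the lengths seen by the attention kernels; without it, only the
    # fresh input lengths should be passed.
    if prefix_caching_enabled:
        return input_lengths + prefix_lens
    return input_lengths

# Two requests with 5 and 3 new tokens and no usable cached prefix.
print(effective_input_lengths(
    torch.tensor([5, 3]), torch.tensor([0, 0]), prefix_caching_enabled=False
))  # tensor([5, 3])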