Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing medusa without prefix caching.
commit b2933b72d0
parent 4c8dcbb76d
@@ -1496,9 +1496,9 @@ class FlashCausalLM(Model):
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["input_lengths"][
+            : input_lengths.shape[0]
+        ] = input_lengths  # + prefix_lens_tensor

         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -1920,7 +1920,7 @@ class FlashCausalLM(Model):
                     prefix_lens=prefix_lens,
                 ),
                 cu_seqlens=cu_seqlen_prefill,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=input_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
|