Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing bug in mllama.
parent 9b50bada65
commit d239884b8e
@@ -256,12 +256,6 @@ class MllamaCausalLM(VlmCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -356,9 +350,9 @@ class MllamaCausalLM(VlmCausalLM):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
         cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][
-            : cache_lengths_tensor.shape[0]
-        ] = cache_lengths_tensor
+        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+            cache_lengths_tensor
+        )
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
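
For context on the block removed in the first hunk: with a sliding-window model, the decode pass overwrites the KV cache like a circular buffer, so only the last max_past() positions are still resident and the max_s handed to attention has to be clamped to that window. A minimal standalone sketch of that idea follows; it is not TGI code, and window_size, write_kv, and attention_span are hypothetical stand-ins for self.max_past() and the cache/attention plumbing.

# Sketch only: why max_s is clamped to the window when the KV cache
# is written as a circular buffer during decode. Names are hypothetical.
import torch

window_size = 4  # plays the role of self.max_past()

# One head, head_dim=2, circular KV cache with window_size slots.
k_cache = torch.zeros(window_size, 2)


def write_kv(position: int, k: torch.Tensor) -> None:
    # Store the key for `position` in a circular buffer: entries older
    # than the window are overwritten instead of growing the cache.
    k_cache[position % window_size] = k


def attention_span(current_length: int) -> int:
    # During decode only the last window_size positions survive in the
    # cache, so the effective max sequence length is clamped.
    return min(window_size, current_length)


if __name__ == "__main__":
    for pos in range(7):  # generate 7 tokens, overwriting old slots
        write_kv(pos, torch.full((2,), float(pos)))

    # The max_s fed to the attention kernel must not exceed what the
    # cache actually holds.
    print(attention_span(current_length=7))  # -> 4, not 7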
|