Fixing bug in mllama.

Nicolas Patry 2025-04-07 09:18:29 +02:00
parent 9b50bada65
commit d239884b8e


@@ -256,12 +256,6 @@ class MllamaCausalLM(VlmCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
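
The hunk above deletes the decode-time clamp of max_s, which (per the removed comments) only makes sense when max_past() reports the size of a bounded, circular-buffer KV cache. As a minimal standalone sketch of what the removed lines did, assuming such a model; the effective_max_s helper and its signature are hypothetical, written purely for illustration:

# Hypothetical sketch of the clamping logic this commit removes.
from typing import Optional


def effective_max_s(
    max_current_length: int,
    cu_seqlen_prefill: Optional[object],
    max_past: Optional[int],
) -> int:
    """Max sequence length for an attention pass.

    max_current_length: longest sequence in the batch.
    cu_seqlen_prefill: non-None only during prefill.
    max_past: circular KV-cache buffer size, or None if unbounded.
    """
    max_s = max_current_length
    if cu_seqlen_prefill is None and max_past is not None:
        # Decode overwrites the KV cache in circular-buffer mode, so the
        # cache never holds more than max_past entries; clamp accordingly.
        max_s = min(max_past, max_s)
    return max_s


# Decode with a 4096-entry circular buffer and an 8000-token sequence:
assert effective_max_s(8000, None, 4096) == 4096
# Prefill (cu_seqlen_prefill is set) is never clamped:
assert effective_max_s(8000, object(), 4096) == 8000
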
@@ -356,9 +350,9 @@ class MllamaCausalLM(VlmCausalLM):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
         cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][
-            : cache_lengths_tensor.shape[0]
-        ] = cache_lengths_tensor
+        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+            cache_lengths_tensor
+        )
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
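
This second hunk only reflows the cache_lengths assignment into Black's parenthesized style; behavior is unchanged. The surrounding pattern is the usual static-input copy for CUDA graph replay: the graph is captured over fixed, padded buffers, so each call zeroes them and copies the current batch in. A minimal PyTorch sketch, where the buffer sizes and the write_static_inputs helper are assumptions for illustration, not TGI's actual code:

import torch

# Buffers captured by the CUDA graph are allocated once, padded to the
# largest supported batch size (here 8).
cuda_graph = {
    "input_lengths": torch.zeros(8, dtype=torch.int32),
    "cache_lengths": torch.zeros(8, dtype=torch.int32),
}


def write_static_inputs(
    cuda_graph: dict,
    input_lengths: torch.Tensor,
    cache_lengths_tensor: torch.Tensor,
) -> None:
    # Zero first: the current batch may be smaller than the padded buffer,
    # and replay would otherwise see stale lengths left over from a
    # previous, larger batch.
    cuda_graph["input_lengths"].zero_()
    cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
    cuda_graph["cache_lengths"].zero_()
    cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
        cache_lengths_tensor
    )


# Example: a batch of 3 sequences fills only the first 3 slots.
write_static_inputs(
    cuda_graph,
    torch.tensor([5, 9, 2], dtype=torch.int32),
    torch.tensor([128, 64, 32], dtype=torch.int32),
)
assert cuda_graph["input_lengths"].tolist() == [5, 9, 2, 0, 0, 0, 0, 0]
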