diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index 28e7489e..a9b610a1 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -256,12 +256,6 @@ class MllamaCausalLM(VlmCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -356,9 +350,9 @@ class MllamaCausalLM(VlmCausalLM):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
         cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][
-            : cache_lengths_tensor.shape[0]
-        ] = cache_lengths_tensor
+        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+            cache_lengths_tensor
+        )
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
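
For context on the first hunk, here is a minimal standalone sketch (not part of the patch) of the clamp that is being dropped, assuming `self.max_past()` returns the sliding-window length or `None` when no window is configured; the helper name and arguments are illustrative only:

```python
from typing import Optional


def clamp_max_s_for_decode(
    max_s: int,
    max_past: Optional[int],
    cu_seqlen_prefill: Optional[object],
) -> int:
    """Illustrate the removed logic: during decode (no prefill cu_seqlens),
    a sliding-window KV-cache is overwritten as a circular buffer, so the
    effective max sequence length is capped at the window size."""
    if cu_seqlen_prefill is None and max_past is not None:
        return min(max_past, max_s)
    return max_s


# Decode with a 4096-token window and 10_000 tokens generated so far:
assert clamp_max_s_for_decode(10_000, 4096, None) == 4096
# During prefill (cu_seqlen_prefill is set), max_s was left untouched:
assert clamp_max_s_for_decode(10_000, 4096, object()) == 10_000
```

The second hunk is purely cosmetic: the slice assignment on `cuda_graph["cache_lengths"]` is rewrapped without changing behavior.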