diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index af9a811c..a9ecef76 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -59,7 +59,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]): assert self.image_indices is not None - batch = super().filter(request_ids) + batch = super(VlmCausalLMBatch, self).filter(request_ids) assert self.image_indices is not None indices = [] for i, request_id in enumerate(request_ids): @@ -85,6 +85,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): ] else: batch.cross_attention_states = None + batch.pixel_values = None return batch @classmethod