diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index e604fd3c..9755ee6d 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -262,8 +262,8 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches):
-        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
+    def concatenate(cls, batches, padded_total_bs: int = 0):
+        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 771cc0a8..13939974 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -48,8 +48,8 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches):
-        batch = super().concatenate(batches)
+    def concatenate(cls, batches, padded_total_bs: int = 0):
+        batch = super().concatenate(batches, padded_total_bs)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
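
Note on this change: both VLM subclasses now accept and forward a `padded_total_bs` argument so the base `FlashCausalLMBatch.concatenate` can pad the merged batch to a fixed size; the default of `0` preserves the old unpadded behavior. On Gaudi/HPU, graphs are compiled for static shapes, so padding concatenated batches up to a pre-chosen bucket size avoids recompilation when batch sizes vary. The overrides themselves only pass the argument through and then drop the per-image tensors (`pixel_values`, `pixel_attention_mask`, `image_sizes`), since image features have already been consumed during prefill. The base-class implementation is not shown in this diff; below is a minimal sketch of what the padding step might look like, assuming padding happens along the batch dimension (the helper name `concatenate_padded` and its body are hypothetical, not the repository's implementation):

```python
# Hypothetical sketch -- NOT the actual FlashCausalLMBatch.concatenate.
# Illustrates how a concatenate that takes padded_total_bs might pad the
# merged batch up to a static bucket size for HPU graph reuse.
from typing import List

import torch


def concatenate_padded(
    input_ids: List[torch.Tensor], padded_total_bs: int = 0
) -> torch.Tensor:
    """Concatenate per-batch tensors, then pad the batch dim to padded_total_bs.

    padded_total_bs == 0 means "no padding requested", matching the
    default in the diff above.
    """
    merged = torch.cat(input_ids, dim=0)  # real requests from all batches
    if padded_total_bs > merged.shape[0]:
        pad_rows = padded_total_bs - merged.shape[0]
        # Pad with zeros so the result matches a pre-compiled shape bucket.
        padding = torch.zeros(pad_rows, *merged.shape[1:], dtype=merged.dtype)
        merged = torch.cat([merged, padding], dim=0)
    return merged


# Example: two batches of 3 and 2 requests, padded up to a bucket of 8.
ids = [torch.arange(3), torch.arange(2)]
out = concatenate_padded(ids, padded_total_bs=8)
assert out.shape[0] == 8
```

With a signature like this, callers that know the target bucket size (e.g. the batching scheduler) can request a padded total, while existing call sites that pass nothing keep the previous semantics.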