diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index a9b9f811..cb8c742e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1075,8 +1075,19 @@ class FlashCausalLMBatch(Batch): input_ids = [0] * extra_pad + input_ids self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) else: - self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) - input_ids_padded_length.extend([extra_pad] * len(self)) + input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self)) + src_pos = 0 + for i in range(len(self)): + end_pos = (i + 1) * max_padded_input_len + start_pos = end_pos - self.input_lengths[i] + input_ids[start_pos:end_pos] = self.input_ids[ + src_pos : src_pos + self.input_lengths[i] + ] + input_ids_padded_length.append( + max_padded_input_len - self.input_lengths[i] + ) + src_pos += self.input_lengths[i] + self.input_ids = input_ids self.input_ids = F.pad( self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0 diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 02b8935d..1be36d09 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -80,7 +80,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]): assert self.image_indices is not None - batch = super().filter(request_ids) + batch = super(FlashVlmCausalLMBatch, self).filter(request_ids) assert self.image_indices is not None indices = [] for i, request_id in enumerate(request_ids): @@ -106,6 +106,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): ] else: batch.cross_attention_states = None + batch.pixel_values = None return batch @classmethod