diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 9883a73f..9081daa0 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -975,7 +975,7 @@ class FlashCausalLMBatch(Batch):
             valid_indices=None,
         )
 
-    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
+    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
         block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
@@ -998,7 +998,7 @@ class FlashCausalLMBatch(Batch):
             bucketing_ctx,
         )
         self.input_ids = F.pad(
-            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
         )
 
         if self.position_ids.dim() == 2:
@@ -1040,7 +1040,7 @@ class FlashCausalLMBatch(Batch):
         )
 
     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
@@ -1064,7 +1064,7 @@ class FlashCausalLMBatch(Batch):
             for input_id in self.input_ids:
                 padded = self.max_input_length - len(input_id) + extra_pad
                 if padded > 0:
-                    input_id = [0] * padded + input_id
+                    input_id = [pad_token_id] * padded + input_id
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
@@ -1072,10 +1072,15 @@ class FlashCausalLMBatch(Batch):
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
-            input_ids = [0] * extra_pad + input_ids
+            input_ids = [pad_token_id] * extra_pad + input_ids
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
-            input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
+            input_ids = torch.full(
+                (max_padded_input_len * len(self),),
+                pad_token_id,
+                dtype=torch.int64,
+                device=self.input_ids.device,
+            )
             src_pos = 0
             for i in range(len(self)):
                 end_pos = (i + 1) * max_padded_input_len
@@ -1090,7 +1095,7 @@ class FlashCausalLMBatch(Batch):
             self.input_ids = input_ids
 
         self.input_ids = F.pad(
-            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
+            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
         )
 
         self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
@@ -1312,8 +1317,9 @@ class FlashCausalLMBatch(Batch):
             self.prefill_next_token_indices = (
                 self.prefill_next_token_indices + input_ids_padded_length_tensor
             )
-        all_input_ids_tensor = torch.zeros(
+        all_input_ids_tensor = torch.full(
             (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+            pad_token_id,
             dtype=torch.int64,
             device="hpu",
         )
@@ -1502,6 +1508,19 @@ class FlashCausalLM(Model):
         )
         self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
+        if tokenizer.pad_token_id is None:
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = (
+                    config.eos_token_id[0]
+                    if isinstance(config.eos_token_id, list)
+                    else config.eos_token_id
+                )
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.pad_token_id = 0
         super().__init__(
             model_id=model_id,
             model=model,
@@ -2274,14 +2293,21 @@ class FlashCausalLM(Model):
                     ),
                     self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                     self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
             else:
                 batch.prepare_for_prefill(
-                    batch.max_input_length, len(batch), self.max_total_tokens
+                    batch.max_input_length,
+                    len(batch),
+                    self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
         else:
             batch.prepare_for_decode(
-                self.dtype, self.use_contiguous_pa, self.bucketing_ctx
+                self.dtype,
+                self.use_contiguous_pa,
+                self.bucketing_ctx,
+                self.tokenizer.pad_token_id,
             )
         if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
             self.set_inputs_embeds(batch)
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 5bd2292e..a9dcdf11 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -554,10 +554,10 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         super().prepare_for_prefill(
-            max_padded_input_len, max_padded_bs, max_total_tokens
+            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
         )
 
         self.has_image_inputs = False
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 1be36d09..a26b9111 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -47,10 +47,10 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
     cross_attention_states: Optional[torch.Tensor] = None
 
     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
-            max_padded_input_len, max_padded_bs, max_total_tokens
+            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
        )
 
    @classmethod
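
For context, a minimal standalone sketch (Python, not part of the patched files) of the pad-token fallback order this patch adds in FlashCausalLM.__init__ and then threads through prepare_for_prefill/prepare_for_decode: tokenizer pad id, then config pad id, then config eos id (first entry when it is a list), then tokenizer eos id, then 0. The helper name resolve_pad_token_id and the SimpleNamespace stand-ins are illustrative assumptions, not code from the patch.

    # Illustrative sketch of the fallback chain introduced by this patch.
    from types import SimpleNamespace


    def resolve_pad_token_id(tokenizer, config) -> int:
        """Return a usable pad token id, mirroring the patch's fallback order."""
        if tokenizer.pad_token_id is not None:
            return tokenizer.pad_token_id
        if config.pad_token_id is not None:
            return config.pad_token_id
        if config.eos_token_id is not None:
            # Configs may store eos_token_id as a list; take the first entry.
            return (
                config.eos_token_id[0]
                if isinstance(config.eos_token_id, list)
                else config.eos_token_id
            )
        if tokenizer.eos_token_id is not None:
            return tokenizer.eos_token_id
        # Last resort, as in the patch: fall back to 0.
        return 0


    if __name__ == "__main__":
        # Hypothetical stand-ins for a tokenizer and model config without a pad token.
        tok = SimpleNamespace(pad_token_id=None, eos_token_id=2)
        cfg = SimpleNamespace(pad_token_id=None, eos_token_id=[128001, 128009])
        print(resolve_pad_token_id(tok, cfg))  # -> 128001 (first eos id from config)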