[Gaudi] use pad_token_id to pad input id (#3268)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-06-17 15:07:25 +08:00 committed by GitHub
parent 3752143b39
commit 0627983c17
4 changed files with 40 additions and 16 deletions

View File

@@ -95,8 +95,6 @@ RUN cd server && \
make gen-server && \
pip install --no-deps -r requirements.txt && \
bash ./dill-0.3.8-patch.sh && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
RUN pip install compressed-tensors==0.9.1

View File

@@ -975,7 +975,7 @@ class FlashCausalLMBatch(Batch):
valid_indices=None,
)
-def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
+def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
block_tables = []
for i, bt in enumerate(self.block_tables):
@@ -998,7 +998,7 @@ class FlashCausalLMBatch(Batch):
bucketing_ctx,
)
self.input_ids = F.pad(
-self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
)
if self.position_ids.dim() == 2:
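
As a side note (not part of the diff), the decode-side padding above can be sketched in isolation; the bucket size, token ids, and pad_token_id below are made-up values:

    import torch
    import torch.nn.functional as F

    input_ids = torch.tensor([11, 22, 33], dtype=torch.int64)  # one id per running request
    padded_bs = 8       # hypothetical bucket size from the bucketing context
    pad_token_id = 2    # hypothetical pad token of the model's tokenizer

    # F.pad on a 1-D tensor takes (left, right) pad sizes; pad on the right up to
    # the bucket size, filling with the pad token instead of the old constant 0.
    padded = F.pad(input_ids, (0, padded_bs - input_ids.shape[0]), value=pad_token_id)
    print(padded)  # tensor([11, 22, 33,  2,  2,  2,  2,  2])
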
@@ -1040,7 +1040,7 @@ class FlashCausalLMBatch(Batch):
)
def prepare_for_prefill(
-self, max_padded_input_len, max_padded_bs, max_total_tokens
+self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
):
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill even with chunking
@@ -1064,7 +1064,7 @@ class FlashCausalLMBatch(Batch):
for input_id in self.input_ids:
padded = self.max_input_length - len(input_id) + extra_pad
if padded > 0:
-input_id = [0] * padded + input_id
+input_id = [pad_token_id] * padded + input_id
input_ids.append(input_id)
input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64)
@@ -1072,10 +1072,15 @@ class FlashCausalLMBatch(Batch):
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
-input_ids = [0] * extra_pad + input_ids
+input_ids = [pad_token_id] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
-input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
+input_ids = torch.full(
+(max_padded_input_len * len(self),),
+pad_token_id,
+dtype=torch.int64,
+device=self.input_ids.device,
+)
src_pos = 0
for i in range(len(self)):
end_pos = (i + 1) * max_padded_input_len
@@ -1090,7 +1095,7 @@ class FlashCausalLMBatch(Batch):
self.input_ids = input_ids
self.input_ids = F.pad(
-self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
+self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
)
self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
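
The prefill path above replaces new_zeros with torch.full so the flat buffer starts out filled with the pad token, and each request is then copied right-aligned into its slot (left padding). A self-contained sketch with hypothetical lengths and token ids, not taken from the commit:

    import torch

    max_padded_input_len = 6   # hypothetical padded prompt length
    pad_token_id = 2           # hypothetical pad token
    requests = [
        torch.tensor([101, 7, 8], dtype=torch.int64),
        torch.tensor([101, 9, 10, 11, 12], dtype=torch.int64),
    ]

    # Allocate the flat buffer pre-filled with the pad token, then copy each
    # request right-aligned into its slot, i.e. left padding.
    flat = torch.full((max_padded_input_len * len(requests),), pad_token_id, dtype=torch.int64)
    for i, ids in enumerate(requests):
        end = (i + 1) * max_padded_input_len
        flat[end - ids.shape[0] : end] = ids

    print(flat.view(len(requests), max_padded_input_len))
    # tensor([[  2,   2,   2, 101,   7,   8],
    #         [  2, 101,   9,  10,  11,  12]])
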
@@ -1312,8 +1317,9 @@ class FlashCausalLMBatch(Batch):
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
-all_input_ids_tensor = torch.zeros(
+all_input_ids_tensor = torch.full(
(max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+pad_token_id,
dtype=torch.int64,
device="hpu",
)
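
Likewise, the per-slot token-history buffer is now preallocated with torch.full instead of torch.zeros, so positions that are never written hold the pad token rather than token id 0. A tiny sketch with made-up shapes (device omitted; the real code allocates on "hpu"):

    import torch

    max_padded_bs, max_total_tokens, pad_token_id = 4, 8, 2
    all_input_ids_tensor = torch.full(
        (max_padded_bs, max_total_tokens), pad_token_id, dtype=torch.int64
    )
    print(all_input_ids_tensor[0])  # tensor([2, 2, 2, 2, 2, 2, 2, 2])
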
@@ -1502,6 +1508,19 @@ class FlashCausalLM(Model):
)
self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
self.max_seq_len_to_capture = 8192
+if tokenizer.pad_token_id is None:
+if config.pad_token_id is not None:
+tokenizer.pad_token_id = config.pad_token_id
+elif config.eos_token_id is not None:
+tokenizer.pad_token_id = (
+config.eos_token_id[0]
+if isinstance(config.eos_token_id, list)
+else config.eos_token_id
+)
+elif tokenizer.eos_token_id is not None:
+tokenizer.pad_token_id = tokenizer.eos_token_id
+else:
+tokenizer.pad_token_id = 0
super().__init__(
model_id=model_id,
model=model,
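
The fallback chain introduced above can be read as a small standalone helper; the sketch below is illustrative only (resolve_pad_token_id is a hypothetical name, not something added by this commit):

    def resolve_pad_token_id(tokenizer, config):
        # tokenizer.pad_token_id -> config.pad_token_id -> config.eos_token_id
        # (first element if it is a list) -> tokenizer.eos_token_id -> 0
        if tokenizer.pad_token_id is not None:
            return tokenizer.pad_token_id
        if getattr(config, "pad_token_id", None) is not None:
            return config.pad_token_id
        eos = getattr(config, "eos_token_id", None)
        if eos is not None:
            return eos[0] if isinstance(eos, list) else eos
        if tokenizer.eos_token_id is not None:
            return tokenizer.eos_token_id
        return 0
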
@@ -2274,14 +2293,21 @@ class FlashCausalLM(Model):
),
self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
self.max_total_tokens,
+self.tokenizer.pad_token_id,
)
else:
batch.prepare_for_prefill(
-batch.max_input_length, len(batch), self.max_total_tokens
+batch.max_input_length,
+len(batch),
+self.max_total_tokens,
+self.tokenizer.pad_token_id,
)
else:
batch.prepare_for_decode(
-self.dtype, self.use_contiguous_pa, self.bucketing_ctx
+self.dtype,
+self.use_contiguous_pa,
+self.bucketing_ctx,
+self.tokenizer.pad_token_id,
)
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
self.set_inputs_embeds(batch)
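
For context on why padding with a literal 0 is risky: in many vocabularies id 0 is a real token, so a 0 pad is indistinguishable from actual input text rather than being the tokenizer's declared pad. A quick check, assuming transformers is installed and the gpt2 checkpoint is reachable (gpt2 is only a familiar example, not a model targeted by this commit):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok.decode([0]))   # "!" -- id 0 is a real token in GPT-2's vocabulary
    print(tok.pad_token_id)  # None -- hence the fallback chain added above
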

View File

@@ -554,10 +554,10 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
return batch
def prepare_for_prefill(
-self, max_padded_input_len, max_padded_bs, max_total_tokens
+self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
):
super().prepare_for_prefill(
-max_padded_input_len, max_padded_bs, max_total_tokens
+max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
)
self.has_image_inputs = False

View File

@@ -47,10 +47,10 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
cross_attention_states: Optional[torch.Tensor] = None
def prepare_for_prefill(
-self, max_padded_input_len, max_padded_bs, max_total_tokens
+self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
):
super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
-max_padded_input_len, max_padded_bs, max_total_tokens
+max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
)
@classmethod
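
Note that FlashMllamaCausalLMBatch forwards via super(FlashVlmCausalLMBatch, self), which starts the method lookup after FlashVlmCausalLMBatch in the MRO and therefore reaches FlashCausalLMBatch.prepare_for_prefill directly, skipping the VLM override. A minimal sketch of that Python pattern with hypothetical classes:

    class Base:
        def prepare(self, pad_token_id):
            return f"Base prepare, pad={pad_token_id}"

    class Middle(Base):
        def prepare(self, pad_token_id):
            return "Middle-specific prepare"

    class Leaf(Middle):
        def prepare(self, pad_token_id):
            # super(Middle, self) starts the lookup *after* Middle in the MRO,
            # so this dispatches to Base.prepare, bypassing Middle's override.
            return super(Middle, self).prepare(pad_token_id)

    print(Leaf().prepare(2))  # Base prepare, pad=2
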