mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00

commit 0627983c17 (parent 3752143b39)

[Gaudi] use pad_token_id to pad input id (#3268)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -95,8 +95,6 @@ RUN cd server && \
    make gen-server && \
    pip install --no-deps -r requirements.txt && \
    bash ./dill-0.3.8-patch.sh && \
    pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
RUN pip install compressed-tensors==0.9.1
@@ -975,7 +975,7 @@ class FlashCausalLMBatch(Batch):
             valid_indices=None,
         )

-    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
+    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
         block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
@@ -998,7 +998,7 @@ class FlashCausalLMBatch(Batch):
             bucketing_ctx,
         )
         self.input_ids = F.pad(
-            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
         )

         if self.position_ids.dim() == 2:
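As context for the hunk above, a minimal sketch (not taken from the repository, values invented) of how torch.nn.functional.pad fills the extra batch slots with a chosen token id rather than 0 when the batch is rounded up to a bucket size:

import torch
import torch.nn.functional as F

# Hypothetical values: 3 live sequences, batch bucket of 5, pad token id 128009.
input_ids = torch.tensor([17, 42, 99], dtype=torch.int64)
padded_bs = 5
pad_token_id = 128009

# Right-pad the flat id tensor up to the padded batch size. With value=0 the
# filler slots would carry token id 0, which may map to a real vocabulary
# token depending on the tokenizer; filling with pad_token_id keeps them inert.
padded = F.pad(input_ids, (0, padded_bs - input_ids.shape[0]), value=pad_token_id)
print(padded)  # tensor([    17,     42,     99, 128009, 128009])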
@@ -1040,7 +1040,7 @@ class FlashCausalLMBatch(Batch):
         )

     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
@@ -1064,7 +1064,7 @@ class FlashCausalLMBatch(Batch):
             for input_id in self.input_ids:
                 padded = self.max_input_length - len(input_id) + extra_pad
                 if padded > 0:
-                    input_id = [0] * padded + input_id
+                    input_id = [pad_token_id] * padded + input_id
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
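A small sketch of the same left-padding idea for Python lists of token ids before concatenation (values invented; the real code also adds extra_pad from bucketing):

import numpy as np

pad_token_id = 2
max_input_length = 4
batch = [[11, 12], [21, 22, 23, 24]]  # hypothetical prompts of unequal length

input_ids = []
input_ids_padded_length = []
for input_id in batch:
    padded = max_input_length - len(input_id)
    if padded > 0:
        # Left-pad with the tokenizer's pad token instead of token id 0.
        input_id = [pad_token_id] * padded + input_id
    input_ids.append(input_id)
    input_ids_padded_length.append(padded)

print(np.concatenate(input_ids, dtype=np.int64))  # [ 2  2 11 12 21 22 23 24]
print(input_ids_padded_length)                    # [2, 0]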
@@ -1072,10 +1072,15 @@ class FlashCausalLMBatch(Batch):
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
-            input_ids = [0] * extra_pad + input_ids
+            input_ids = [pad_token_id] * extra_pad + input_ids
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
-            input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
+            input_ids = torch.full(
+                (max_padded_input_len * len(self),),
+                pad_token_id,
+                dtype=torch.int64,
+                device=self.input_ids.device,
+            )
             src_pos = 0
             for i in range(len(self)):
                 end_pos = (i + 1) * max_padded_input_len
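A brief sketch (illustrative sizes only) of the torch.full pre-allocation that takes the place of new_zeros, so any slot the copy loop never touches holds the pad token instead of token id 0:

import torch

pad_token_id = 2
max_padded_input_len = 4
batch = [torch.tensor([11, 12]), torch.tensor([21, 22, 23])]  # hypothetical

# Pre-fill the flat buffer with the pad token, then copy each sequence into
# the tail of its slot (left padding), roughly as prepare_for_prefill does.
input_ids = torch.full(
    (max_padded_input_len * len(batch),),
    pad_token_id,
    dtype=torch.int64,
)
for i, ids in enumerate(batch):
    end_pos = (i + 1) * max_padded_input_len
    input_ids[end_pos - len(ids) : end_pos] = ids

print(input_ids)  # tensor([ 2,  2, 11, 12,  2, 21, 22, 23])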
@@ -1090,7 +1095,7 @@ class FlashCausalLMBatch(Batch):
             self.input_ids = input_ids

         self.input_ids = F.pad(
-            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
+            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
         )

         self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
@@ -1312,8 +1317,9 @@ class FlashCausalLMBatch(Batch):
             self.prefill_next_token_indices = (
                 self.prefill_next_token_indices + input_ids_padded_length_tensor
             )
-        all_input_ids_tensor = torch.zeros(
+        all_input_ids_tensor = torch.full(
             (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+            pad_token_id,
             dtype=torch.int64,
             device="hpu",
         )
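The same idea applied to the two-dimensional token history: a sketch with invented sizes (and CPU instead of device="hpu") of growing all_input_ids_tensor with torch.full so unused rows and columns hold the pad token:

import torch

pad_token_id = 2
max_total_tokens = 6
max_padded_bs = 4
existing = torch.tensor([[11, 12, 13], [21, 22, 23]], dtype=torch.int64)  # hypothetical

all_input_ids_tensor = torch.full(
    (max_padded_bs, max(max_total_tokens, existing.shape[-1])),
    pad_token_id,
    dtype=torch.int64,
)
# Keep the existing history; everything else stays pad_token_id.
all_input_ids_tensor[: existing.shape[0], : existing.shape[-1]] = existing
print(all_input_ids_tensor)
# tensor([[11, 12, 13,  2,  2,  2],
#         [21, 22, 23,  2,  2,  2],
#         [ 2,  2,  2,  2,  2,  2],
#         [ 2,  2,  2,  2,  2,  2]])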
@@ -1502,6 +1508,19 @@ class FlashCausalLM(Model):
         )
         self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
+        if tokenizer.pad_token_id is None:
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = (
+                    config.eos_token_id[0]
+                    if isinstance(config.eos_token_id, list)
+                    else config.eos_token_id
+                )
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.pad_token_id = 0
         super().__init__(
             model_id=model_id,
             model=model,
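The fallback chain added above can be read as a small helper; resolve_pad_token_id is a hypothetical name, and config/tokenizer are assumed to expose the usual transformers attributes:

def resolve_pad_token_id(tokenizer, config) -> int:
    """Same priority order as the constructor: an existing tokenizer pad id,
    then config.pad_token_id, then config.eos_token_id (first entry if it is
    a list), then tokenizer.eos_token_id, and finally 0 as a last resort."""
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    if config.pad_token_id is not None:
        return config.pad_token_id
    if config.eos_token_id is not None:
        eos = config.eos_token_id
        return eos[0] if isinstance(eos, list) else eos
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    return 0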
@@ -2274,14 +2293,21 @@ class FlashCausalLM(Model):
                     ),
                     self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                     self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
             else:
                 batch.prepare_for_prefill(
-                    batch.max_input_length, len(batch), self.max_total_tokens
+                    batch.max_input_length,
+                    len(batch),
+                    self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
         else:
             batch.prepare_for_decode(
-                self.dtype, self.use_contiguous_pa, self.bucketing_ctx
+                self.dtype,
+                self.use_contiguous_pa,
+                self.bucketing_ctx,
+                self.tokenizer.pad_token_id,
             )
         if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
             self.set_inputs_embeds(batch)
@@ -554,10 +554,10 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
         return batch

     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         super().prepare_for_prefill(
-            max_padded_input_len, max_padded_bs, max_total_tokens
+            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
         )

         self.has_image_inputs = False
@@ -47,10 +47,10 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
     cross_attention_states: Optional[torch.Tensor] = None

     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
-            max_padded_input_len, max_padded_bs, max_total_tokens
+            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
         )

     @classmethod
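The Mllama batch skips its direct parent and reuses the base implementation via two-argument super; a toy sketch of that lookup (class names invented):

class Base:
    def prepare(self, pad_token_id):
        return f"Base.prepare(pad={pad_token_id})"

class Vlm(Base):
    def prepare(self, pad_token_id):
        return "Vlm extras + " + super().prepare(pad_token_id)

class Mllama(Vlm):
    def prepare(self, pad_token_id):
        # super(Vlm, self) starts the MRO search *after* Vlm, so Base.prepare
        # runs and Vlm's override is bypassed, mirroring how
        # FlashMllamaCausalLMBatch calls FlashCausalLMBatch.prepare_for_prefill.
        return super(Vlm, self).prepare(pad_token_id)

print(Mllama().prepare(0))  # Base.prepare(pad=0)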