commit a4df5bc64a
parent e5e22993e7

    faster
@@ -47,7 +47,10 @@ class FlashNeoXBatch(Batch):
     past_key_values: Optional[torch.Tensor]
 
     # All tokens
-    all_input_ids: List[torch.Tensor]
+    all_input_ids: List[List[int]]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
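Note: the type change above keeps each request's token history as a plain Python list instead of a tensor, so appending a generated token no longer rebuilds a tensor every decode step. A minimal sketch of the two bookkeeping styles this hunk switches between (values are illustrative, not from the repository):

import torch

prompt_ids = [5, 11, 42]
next_token_id = 7

# Before: one tensor per request, extended with torch.cat every decode step
ids_tensor = torch.tensor(prompt_ids)
ids_tensor = torch.cat([ids_tensor, torch.tensor([next_token_id])])  # reallocates

# After: one Python list per request, extended in place
ids_list = list(prompt_ids)
ids_list.append(next_token_id)  # O(1) append, no tensor rebuild

assert ids_tensor.tolist() == ids_list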
@@ -70,6 +73,9 @@ class FlashNeoXBatch(Batch):
         cu_seqlens = [0]
         max_seqlen = 0
 
+        input_lengths = []
+        all_input_ids = []
+
         next_token_choosers = []
         stopping_criterias = []
 
@@ -81,9 +87,11 @@ class FlashNeoXBatch(Batch):
                 .squeeze(0)
             )
             input_ids.append(tokenized_input)
+            all_input_ids.append(tokenized_input.tolist())
             position_ids.append(
                 torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device)
             )
+            input_lengths.append(len(tokenized_input))
             cu_seqlens.append(len(tokenized_input))
             max_seqlen = max(max_seqlen, len(tokenized_input))
 
@@ -92,7 +100,6 @@ class FlashNeoXBatch(Batch):
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
 
-        all_input_ids = input_ids
         input_ids = torch.concat(input_ids).unsqueeze(1)
         position_ids = torch.concat(position_ids)
         cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
@@ -105,6 +112,7 @@ class FlashNeoXBatch(Batch):
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
             past_key_values=None,
+            input_lengths=input_lengths,
             all_input_ids=all_input_ids,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
@@ -191,6 +199,8 @@ class FlashNeoX(Model):
             batch.past_key_values,
         )
 
+        device = out.device
+
         # List of indices to cache
         next_batch_keep_indices = []
 
@@ -200,14 +210,19 @@ class FlashNeoX(Model):
         next_batch_cu_seqlens = [0]
         next_batch_max_seqlen = 0
         next_batch_past_key_values = []
+        next_batch_input_lengths = []
         next_batch_all_input_ids = []
 
+        # Cumulative length
+        cumulative_length = 0
+
         # Results
         generations: List[Generation] = []
 
         # Zipped iterator
         iterator = zip(
             batch.requests,
+            batch.input_lengths,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -216,14 +231,14 @@ class FlashNeoX(Model):
         # For each member of the batch
         for i, (
             request,
+            input_length,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = batch.cu_seqlens[i]
-            end_index = batch.cu_seqlens[i + 1]
-            seq_length = (end_index - start_index).item()
+            start_index = cumulative_length
+            end_index = cumulative_length + input_length
 
             if batch.past_key_values is None:
                 # Prefill mode
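Note: with input_lengths carried on the batch, the loop above derives each request's slice boundaries from a running Python counter instead of indexing batch.cu_seqlens and calling .item() on a device tensor every iteration. A small sketch of that indexing pattern (lengths and the flattened tensor are made up):

import torch

input_lengths = [3, 5, 2]                # tokens per request
flat = torch.arange(sum(input_lengths))  # stand-in for the flattened batch

cumulative_length = 0
for input_length in input_lengths:
    start_index = cumulative_length
    end_index = cumulative_length + input_length
    request_slice = flat[start_index:end_index]  # this request's rows
    assert request_slice.shape[0] == input_length
    cumulative_length += input_length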
@ -236,23 +251,28 @@ class FlashNeoX(Model):
|
||||
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_input_ids.view(1, -1), logits
|
||||
all_input_ids, logits
|
||||
)
|
||||
next_token_id = next_token_id.to("cpu")
|
||||
logprobs = logprobs.to("cpu")
|
||||
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_id_item = next_token_id_squeezed.item()
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
|
||||
new_input_length = seq_length + 1
|
||||
all_input_ids.append(next_token_id_item)
|
||||
# all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Generated token
|
||||
next_token_logprob = logprobs[-1, next_token_id]
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_text = self.decode_token(
|
||||
next_token_id_squeezed,
|
||||
next_token_id_item,
|
||||
)
|
||||
|
||||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(
|
||||
next_token_id_squeezed,
|
||||
next_token_id_item,
|
||||
next_token_text,
|
||||
)
|
||||
|
||||
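Note: the block above now transfers next_token_id and logprobs to the CPU once and then works with a Python int (next_token_id_item) for decoding, the stopping criteria, and the token history. A rough sketch of that pattern with made-up shapes, using argmax as a stand-in for the real next_token_chooser:

import torch

logits = torch.randn(1, 10)                                  # fake model output
next_token_id = torch.argmax(logits, dim=-1, keepdim=True)   # shape [1, 1]

next_token_id = next_token_id.to("cpu")
next_token_id_squeezed = next_token_id.squeeze()
next_token_id_item = next_token_id_squeezed.item()           # a Python int

all_input_ids = [3, 8, 4]
all_input_ids.append(next_token_id_item)                     # list append, no torch.cat
assert isinstance(next_token_id_item, int)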
@@ -279,10 +299,11 @@ class FlashNeoX(Model):
                 generated_text = None
                 next_batch_keep_indices.append(i)
                 next_batch_input_ids.append(next_token_id)
-                next_batch_position_ids.append(seq_length)
+                next_batch_position_ids.append(input_length)
                 next_batch_cu_seqlens.append(
                     next_batch_cu_seqlens[i] + new_input_length
                 )
+                next_batch_input_lengths.append(new_input_length)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
 
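Note: next_batch_cu_seqlens stays a prefix sum of the per-request lengths after each kept request grows by one token. An illustrative check of that invariant with made-up lengths:

new_input_lengths = [4, 6, 3]  # each request grew by one token this step

next_batch_cu_seqlens = [0]
for i, new_input_length in enumerate(new_input_lengths):
    next_batch_cu_seqlens.append(next_batch_cu_seqlens[i] + new_input_length)

assert next_batch_cu_seqlens == [0, 4, 10, 13]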
@@ -290,7 +311,7 @@ class FlashNeoX(Model):
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:].unsqueeze(1)
+                    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
                 ).squeeze(1)[:-1].tolist()
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
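Note: because all_input_ids is now a list of ints, the prefill log-probability lookup above wraps it in a tensor before gather. An illustrative reproduction of that computation with made-up shapes (a 3-token prompt plus one generated token already appended):

import torch

all_input_ids = [2, 7, 1, 9]                                # 3 prompt ids + 1 generated id
logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)    # one row per prompt position

# Log-probability of each prompt token under the previous position's prediction;
# NaN for the first prompt token, and the generated token's entry is dropped.
prefill_logprobs = [float("nan")] + logprobs.gather(
    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
).squeeze(1)[:-1].tolist()

prefill_token_ids = all_input_ids[:-1]
assert len(prefill_logprobs) == len(prefill_token_ids) == 3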
@@ -307,14 +328,15 @@ class FlashNeoX(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
+                next_token_id_item,
                 next_token_logprob,
                 next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                next_token_id_item in self.all_special_ids,
                 generated_text,
             )
 
             generations.append(generation)
+            cumulative_length += input_length
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
@@ -337,7 +359,6 @@ class FlashNeoX(Model):
             next_batch_stopping_criterias = batch.stopping_criterias
 
         # Create final next batch tensors
-        device = out.device
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32, device=device
         )
@@ -348,7 +369,7 @@ class FlashNeoX(Model):
             next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values)
         else:
-            next_batch_input_ids = next_batch_input_ids[0]
+            next_batch_input_ids = next_batch_input_ids[0].to(device)
             next_batch_past_key_values = next_batch_past_key_values[0]
 
         next_batch = FlashNeoXBatch(
@@ -359,6 +380,7 @@ class FlashNeoX(Model):
             cu_seqlens=next_batch_cu_seqlens,
             max_seqlen=next_batch_max_seqlen,
             past_key_values=next_batch_past_key_values,
+            input_lengths=next_batch_input_lengths,
             all_input_ids=next_batch_all_input_ids,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,