mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

faster

This commit is contained in:
parent e5e22993e7
commit a4df5bc64a
@@ -47,7 +47,10 @@ class FlashNeoXBatch(Batch):
     past_key_values: Optional[torch.Tensor]
 
     # All tokens
-    all_input_ids: List[torch.Tensor]
+    all_input_ids: List[List[int]]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
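The field change above is the core of this "faster" commit: per-request token histories become plain Python lists instead of GPU tensors, so appending one generated token is amortized O(1) rather than a full tensor re-allocation. A minimal, self-contained comparison of the two growth patterns (illustrative only, not part of the commit):

```python
import time
import torch

steps = 1024

# Tensor growth: torch.cat copies the whole history every step -> O(n) each.
ids_t = torch.zeros(1, dtype=torch.int64)
start = time.perf_counter()
for i in range(steps):
    ids_t = torch.cat([ids_t, torch.tensor([i])])
t_tensor = time.perf_counter() - start

# List growth: append is amortized O(1) each step.
ids_l = [0]
start = time.perf_counter()
for i in range(steps):
    ids_l.append(i)
t_list = time.perf_counter() - start

print(f"torch.cat loop: {t_tensor:.4f}s   list.append loop: {t_list:.6f}s")
```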
@@ -70,6 +73,9 @@ class FlashNeoXBatch(Batch):
         cu_seqlens = [0]
         max_seqlen = 0
 
+        input_lengths = []
+        all_input_ids = []
+
         next_token_choosers = []
         stopping_criterias = []
 
@@ -81,9 +87,11 @@ class FlashNeoXBatch(Batch):
                 .squeeze(0)
             )
             input_ids.append(tokenized_input)
+            all_input_ids.append(tokenized_input.tolist())
             position_ids.append(
                 torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device)
             )
+            input_lengths.append(len(tokenized_input))
             cu_seqlens.append(len(tokenized_input))
             max_seqlen = max(max_seqlen, len(tokenized_input))
 
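For context on the bookkeeping these hunks build up: flash-attention-style kernels take the whole batch as one flat token tensor plus cumulative sequence boundaries (`cu_seqlens`), with request i owning the slice [cu_seqlens[i], cu_seqlens[i + 1]). A sketch of that layout, with illustrative values (the exact accumulation step in the repo may differ):

```python
import torch

input_lengths = [3, 5, 2]                        # three requests in one batch
cu_seqlens = torch.tensor([0] + input_lengths).cumsum(0)
print(cu_seqlens.tolist())                       # [0, 3, 8, 10]

flat_tokens = torch.arange(10)                   # stand-in for concatenated ids
for i in range(len(input_lengths)):
    start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
    print(f"request {i}: {flat_tokens[start:end].tolist()}")
```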
@@ -92,7 +100,6 @@ class FlashNeoXBatch(Batch):
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
 
-        all_input_ids = input_ids
         input_ids = torch.concat(input_ids).unsqueeze(1)
         position_ids = torch.concat(position_ids)
         cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
@@ -105,6 +112,7 @@ class FlashNeoXBatch(Batch):
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
             past_key_values=None,
+            input_lengths=input_lengths,
             all_input_ids=all_input_ids,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
@@ -191,6 +199,8 @@ class FlashNeoX(Model):
             batch.past_key_values,
         )
 
+        device = out.device
+
         # List of indices to cache
         next_batch_keep_indices = []
 
@@ -200,14 +210,19 @@ class FlashNeoX(Model):
         next_batch_cu_seqlens = [0]
         next_batch_max_seqlen = 0
         next_batch_past_key_values = []
+        next_batch_input_lengths = []
         next_batch_all_input_ids = []
 
+        # Cumulative length
+        cumulative_length = 0
+
         # Results
         generations: List[Generation] = []
 
         # Zipped iterator
         iterator = zip(
             batch.requests,
+            batch.input_lengths,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -216,14 +231,14 @@ class FlashNeoX(Model):
         # For each member of the batch
         for i, (
             request,
+            input_length,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = batch.cu_seqlens[i]
-            end_index = batch.cu_seqlens[i + 1]
-            seq_length = (end_index - start_index).item()
+            start_index = cumulative_length
+            end_index = cumulative_length + input_length
 
             if batch.past_key_values is None:
                 # Prefill mode
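This hunk is where the speedup comes from: `batch.cu_seqlens[i]` indexed a GPU tensor, and the old `.item()` call forced a device-to-host synchronization for every request in the batch. Tracking the offsets with a host-side counter removes those syncs. A minimal sketch of the pattern, with illustrative lengths:

```python
# Host-side offsets: no GPU tensor indexing, no per-request .item() sync.
input_lengths = [3, 5, 2]           # illustrative, lives in Python
cumulative_length = 0
for input_length in input_lengths:
    start_index = cumulative_length
    end_index = cumulative_length + input_length
    # out[start_index:end_index] would be this request's hidden states
    print(start_index, end_index)
    cumulative_length += input_length
```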
@@ -236,23 +251,28 @@ class FlashNeoX(Model):
 
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_input_ids.view(1, -1), logits
+                all_input_ids, logits
             )
+            next_token_id = next_token_id.to("cpu")
+            logprobs = logprobs.to("cpu")
+
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_id_item = next_token_id_squeezed.item()
 
             # Append next token to all tokens
-            all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
-            new_input_length = seq_length + 1
+            all_input_ids.append(next_token_id_item)
+            # all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
+            new_input_length = input_length + 1
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
-            next_token_id_squeezed = next_token_id.squeeze()
             next_token_text = self.decode_token(
-                next_token_id_squeezed,
+                next_token_id_item,
             )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token_id_squeezed,
+                next_token_id_item,
                 next_token_text,
             )
 
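The added `.to("cpu")` calls consolidate device-to-host traffic: the sampled id and the logprobs cross to the host once, and the subsequent `.squeeze()`, `.item()`, and list append are pure CPU work. A rough sketch of the idea (argmax stands in for the sampling inside `next_token_chooser`; the CPU fallback just makes it runnable anywhere):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(1, 50432, device=device)        # illustrative vocab size

next_token_id = logits.argmax(dim=-1, keepdim=True)  # still on the device
next_token_id = next_token_id.to("cpu")              # one explicit transfer
next_token_id_item = next_token_id.squeeze().item()  # pure host work from here
print(next_token_id_item)
```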
@@ -279,10 +299,11 @@ class FlashNeoX(Model):
                 generated_text = None
                 next_batch_keep_indices.append(i)
                 next_batch_input_ids.append(next_token_id)
-                next_batch_position_ids.append(seq_length)
+                next_batch_position_ids.append(input_length)
                 next_batch_cu_seqlens.append(
                     next_batch_cu_seqlens[i] + new_input_length
                 )
+                next_batch_input_lengths.append(new_input_length)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
 
@@ -290,7 +311,7 @@ class FlashNeoX(Model):
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:].unsqueeze(1)
+                    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
                 ).squeeze(1)[:-1].tolist()
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
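Since `all_input_ids` is now a list, it must be wrapped in `torch.tensor` before the `gather`. What the expression computes: row t of the prefill logprobs holds the model's logprob for prompt token t + 1, the first prompt token gets `nan` (nothing predicts it), and the trailing `[:-1]` drops the entry for the token generated after prefill (which was already appended to the list). A worked toy example (ids and vocab size are made up):

```python
import torch

prompt_ids = [5, 2, 7]                        # illustrative prompt token ids
generated_id = 4                              # token sampled after prefill
all_input_ids = prompt_ids + [generated_id]   # list form, as in this commit

# One row of logprobs per prompt position; row t predicts token t + 1.
logprobs = torch.log_softmax(torch.randn(len(prompt_ids), 10), dim=-1)

prefill_logprobs = [float("nan")] + logprobs.gather(
    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
).squeeze(1)[:-1].tolist()

print(len(prefill_logprobs) == len(prompt_ids))  # True: one entry per prompt token
```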
@@ -307,14 +328,15 @@ class FlashNeoX(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
+                next_token_id_item,
                 next_token_logprob,
                 next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                next_token_id_item in self.all_special_ids,
                 generated_text,
             )
 
             generations.append(generation)
+            cumulative_length += input_length
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
@@ -337,7 +359,6 @@ class FlashNeoX(Model):
             next_batch_stopping_criterias = batch.stopping_criterias
 
         # Create final next batch tensors
-        device = out.device
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32, device=device
         )
@@ -348,7 +369,7 @@ class FlashNeoX(Model):
             next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values)
         else:
-            next_batch_input_ids = next_batch_input_ids[0]
+            next_batch_input_ids = next_batch_input_ids[0].to(device)
             next_batch_past_key_values = next_batch_past_key_values[0]
 
         next_batch = FlashNeoXBatch(
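Because the sampled ids now live on the CPU, the single-request branch has to move the tensor back to the model's device (captured earlier as `device = out.device`) before the next forward pass. A trivial sketch of that round trip, assuming a CUDA device when available:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
next_batch_input_ids = [torch.tensor([[42]])]        # one request, id on CPU
next_batch_input_ids = next_batch_input_ids[0].to(device)
print(next_batch_input_ids.device)
```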
@@ -359,6 +380,7 @@ class FlashNeoX(Model):
             cu_seqlens=next_batch_cu_seqlens,
             max_seqlen=next_batch_max_seqlen,
             past_key_values=next_batch_past_key_values,
+            input_lengths=next_batch_input_lengths,
             all_input_ids=next_batch_all_input_ids,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,