OlivierDehaene 2023-03-23 14:01:35 +01:00
parent e5e22993e7
commit a4df5bc64a


@@ -47,7 +47,10 @@ class FlashNeoXBatch(Batch):
     past_key_values: Optional[torch.Tensor]
 
     # All tokens
-    all_input_ids: List[torch.Tensor]
+    all_input_ids: List[List[int]]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
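
The type change above is the heart of the commit: per-request token tensors become plain Python lists, and a parallel input_lengths list records how long each sequence is. A minimal sketch of the resulting fields (an abridged stand-in, not the full FlashNeoXBatch):

    from dataclasses import dataclass
    from typing import List, Optional

    import torch


    @dataclass
    class BatchSketch:
        # Flattened, unpadded ids for the whole batch
        input_ids: torch.Tensor
        # All tokens per request as plain Python ints: cheap to append to
        # and readable without a GPU round-trip
        all_input_ids: List[List[int]]
        # Length of every sequence present in the batch
        input_lengths: List[int]
        # Concatenated key/value cache; None before prefill
        past_key_values: Optional[torch.Tensor] = None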
@@ -70,6 +73,9 @@ class FlashNeoXBatch(Batch):
         cu_seqlens = [0]
         max_seqlen = 0
 
+        input_lengths = []
+        all_input_ids = []
+
         next_token_choosers = []
         stopping_criterias = []
@ -81,9 +87,11 @@ class FlashNeoXBatch(Batch):
.squeeze(0) .squeeze(0)
) )
input_ids.append(tokenized_input) input_ids.append(tokenized_input)
all_input_ids.append(tokenized_input.tolist())
position_ids.append( position_ids.append(
torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device) torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device)
) )
input_lengths.append(len(tokenized_input))
cu_seqlens.append(len(tokenized_input)) cu_seqlens.append(len(tokenized_input))
max_seqlen = max(max_seqlen, len(tokenized_input)) max_seqlen = max(max_seqlen, len(tokenized_input))
@@ -92,7 +100,6 @@ class FlashNeoXBatch(Batch):
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
 
-        all_input_ids = input_ids
         input_ids = torch.concat(input_ids).unsqueeze(1)
         position_ids = torch.concat(position_ids)
         cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
@@ -105,6 +112,7 @@ class FlashNeoXBatch(Batch):
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
             past_key_values=None,
+            input_lengths=input_lengths,
             all_input_ids=all_input_ids,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
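
Taken together, the from_pb hunks above fill the new bookkeeping in a single pass over the requests. A runnable sketch of that loop, with pre-tokenized prompts standing in for the real tokenizer (build_batch is a hypothetical helper, and it appends cumulative boundaries to cu_seqlens directly):

    from typing import List

    import torch


    def build_batch(tokenized: List[List[int]], device: str = "cpu"):
        input_ids, position_ids = [], []
        cu_seqlens, max_seqlen = [0], 0
        input_lengths: List[int] = []
        all_input_ids: List[List[int]] = []

        for ids in tokenized:
            input_ids.append(torch.tensor(ids, dtype=torch.int64, device=device))
            all_input_ids.append(list(ids))
            position_ids.append(
                torch.arange(0, len(ids), dtype=torch.int32, device=device)
            )
            input_lengths.append(len(ids))
            # Boundary of request i+1 = boundary of request i + its length
            cu_seqlens.append(cu_seqlens[-1] + len(ids))
            max_seqlen = max(max_seqlen, len(ids))

        return (
            torch.concat(input_ids).unsqueeze(1),
            torch.concat(position_ids),
            torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
            max_seqlen,
            input_lengths,
            all_input_ids,
        )

With cu_seqlens laid out this way, request i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the flattened batch, the layout flash attention's variable-length kernels expect.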
@@ -191,6 +199,8 @@ class FlashNeoX(Model):
             batch.past_key_values,
         )
 
+        device = out.device
+
         # List of indices to cache
         next_batch_keep_indices = []
@ -200,14 +210,19 @@ class FlashNeoX(Model):
next_batch_cu_seqlens = [0] next_batch_cu_seqlens = [0]
next_batch_max_seqlen = 0 next_batch_max_seqlen = 0
next_batch_past_key_values = [] next_batch_past_key_values = []
next_batch_input_lengths = []
next_batch_all_input_ids = [] next_batch_all_input_ids = []
# Cumulative length
cumulative_length = 0
# Results # Results
generations: List[Generation] = [] generations: List[Generation] = []
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
batch.requests, batch.requests,
batch.input_lengths,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
@@ -216,14 +231,14 @@ class FlashNeoX(Model):
         # For each member of the batch
         for i, (
             request,
+            input_length,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = batch.cu_seqlens[i]
-            end_index = batch.cu_seqlens[i + 1]
-            seq_length = (end_index - start_index).item()
+            start_index = cumulative_length
+            end_index = cumulative_length + input_length
 
             if batch.past_key_values is None:
                 # Prefill mode
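
This replacement is what the input_lengths plumbing buys: batch.cu_seqlens is a CUDA tensor, so indexing it and calling .item() forced a device synchronization for every request, while a Python-side running total produces the same slice bounds with no GPU traffic. A sketch (lengths and the vocab dimension are made up):

    import torch

    input_lengths = [5, 3, 7]                  # per-request sequence lengths
    out = torch.randn(sum(input_lengths), 16)  # flattened batch logits

    cumulative_length = 0
    for input_length in input_lengths:
        start_index = cumulative_length
        end_index = cumulative_length + input_length
        logits = out[start_index:end_index]    # this request's rows
        assert logits.shape[0] == input_length
        cumulative_length += input_length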
@@ -236,23 +251,28 @@ class FlashNeoX(Model):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_input_ids.view(1, -1), logits
+                all_input_ids, logits
             )
+            next_token_id = next_token_id.to("cpu")
+            logprobs = logprobs.to("cpu")
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_id_item = next_token_id_squeezed.item()
 
             # Append next token to all tokens
-            all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
-            new_input_length = seq_length + 1
+            all_input_ids.append(next_token_id_item)
+            # all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
+            new_input_length = input_length + 1
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
-            next_token_id_squeezed = next_token_id.squeeze()
             next_token_text = self.decode_token(
-                next_token_id_squeezed,
+                next_token_id_item,
             )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token_id_squeezed,
+                next_token_id_item,
                 next_token_text,
             )
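
The added lines transfer the sampled id and its logprobs to the CPU once, then do all further bookkeeping with a plain Python int, which matches all_input_ids now being a list. A sketch of the pattern, with greedy argmax standing in for next_token_chooser:

    import torch

    logprobs = torch.log_softmax(torch.randn(7, 16), dim=-1)    # one request
    next_token_id = logprobs[-1:].argmax(dim=-1, keepdim=True)  # shape (1, 1)

    # One transfer up front instead of a device sync at every .item()
    next_token_id = next_token_id.to("cpu")
    logprobs = logprobs.to("cpu")
    next_token_id_item = next_token_id.squeeze().item()

    all_input_ids = [11, 7, 9, 13, 2, 5, 8]
    all_input_ids.append(next_token_id_item)  # plain append, no torch.cat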
@@ -279,10 +299,11 @@ class FlashNeoX(Model):
                 generated_text = None
                 next_batch_keep_indices.append(i)
                 next_batch_input_ids.append(next_token_id)
-                next_batch_position_ids.append(seq_length)
+                next_batch_position_ids.append(input_length)
                 next_batch_cu_seqlens.append(
                     next_batch_cu_seqlens[i] + new_input_length
                 )
+                next_batch_input_lengths.append(new_input_length)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
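
Appending input_length as the next position id works because positions are 0-indexed: a sequence of length n occupies positions 0..n-1, so the token decoded next sits at position n. In numbers:

    input_length = 5                  # tokens at positions 0..4
    next_position_id = input_length   # the next token lands at position 5
    new_input_length = input_length + 1
    assert next_position_id == new_input_length - 1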
@@ -290,7 +311,7 @@ class FlashNeoX(Model):
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:].unsqueeze(1)
+                    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
                 ).squeeze(1)[:-1].tolist()
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
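
Because all_input_ids is now a list, the gather index has to be rebuilt as a tensor, hence torch.tensor(all_input_ids[1:]). The logic around it: logprobs row k scores token k+1, so the first prompt token gets NaN and the last row, which scores the freshly generated token, is dropped. A runnable sketch with made-up ids and a tiny vocabulary:

    import torch

    prompt_len = 3
    all_input_ids = [11, 7, 9, 13]  # 3 prompt ids + the freshly generated id
    logprobs = torch.log_softmax(torch.randn(prompt_len, 16), dim=-1)

    prefill_logprobs = [float("nan")] + logprobs.gather(
        1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
    ).squeeze(1)[:-1].tolist()

    prefill_token_ids = all_input_ids[:-1]
    assert len(prefill_logprobs) == len(prefill_token_ids) == prompt_len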
@@ -307,14 +328,15 @@ class FlashNeoX(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
+                next_token_id_item,
                 next_token_logprob,
                 next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                next_token_id_item in self.all_special_ids,
                 generated_text,
             )
 
             generations.append(generation)
+            cumulative_length += input_length
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
@@ -337,7 +359,6 @@ class FlashNeoX(Model):
             next_batch_stopping_criterias = batch.stopping_criterias
 
         # Create final next batch tensors
-        device = out.device
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32, device=device
         )
@@ -348,7 +369,7 @@ class FlashNeoX(Model):
             next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values)
         else:
-            next_batch_input_ids = next_batch_input_ids[0]
+            next_batch_input_ids = next_batch_input_ids[0].to(device)
             next_batch_past_key_values = next_batch_past_key_values[0]
 
         next_batch = FlashNeoXBatch(
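
The new .to(device) is needed because next_token_id was moved to the CPU during bookkeeping; with a single remaining request the torch.concat branch never runs, so the id must be sent back to the model's device by hand. A sketch, assuming device was captured from out.device as in the earlier hunk:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    next_batch_input_ids = [torch.tensor([[42]])]  # kept on CPU above

    if len(next_batch_input_ids) > 1:
        next_input_ids = torch.concat(next_batch_input_ids, dim=0)
    else:
        # Single request left: no concat, move the id back explicitly
        next_input_ids = next_batch_input_ids[0].to(device)
    assert next_input_ids.device.type == device.type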
@@ -359,6 +380,7 @@ class FlashNeoX(Model):
             cu_seqlens=next_batch_cu_seqlens,
             max_seqlen=next_batch_max_seqlen,
             past_key_values=next_batch_past_key_values,
+            input_lengths=next_batch_input_lengths,
             all_input_ids=next_batch_all_input_ids,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,