From b2a468176d8662e801899a7069188586d3294bf0 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 30 Jan 2023 11:37:36 +0100 Subject: [PATCH] working integration tests --- server/text_generation/models/causal_lm.py | 20 +-- server/text_generation/models/seq2seq_lm.py | 175 ++++++++++---------- server/text_generation/models/types.py | 4 + 3 files changed, 102 insertions(+), 97 deletions(-) diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 0c110dcd..735c94a5 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -311,7 +311,7 @@ class CausalLM(Model): next_batch_max_sequence_length = 0 # Results - results = [] + generations: List[Generation] = [] # Zipped iterator iterator = zip( @@ -343,7 +343,9 @@ class CausalLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.decode(next_token_id.squeeze()) + next_token_text = self.tokenizer.decode(next_token_id_squeezed, + clean_up_tokenization_spaces=False, + skip_special_tokens=False) # Evaluate stopping criteria stop, reason = stopping_criteria( @@ -381,11 +383,9 @@ class CausalLM(Model): ) # Prefill - if stopping_criteria.current_tokens == 0: + if stopping_criteria.current_tokens == 1: # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + logprobs[ - -new_input_length:-1 - ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist() + prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist() prefill_token_ids = all_input_ids[-new_input_length:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, @@ -398,7 +398,7 @@ class CausalLM(Model): else: prefill_tokens = None - result = Generation( + generation = Generation( request.id, prefill_tokens, next_token_id_squeezed, @@ -407,11 +407,11 @@ class CausalLM(Model): generated_text, ) - results.append(result) + generations.append(generation) # We finished all generations in the batch; there is no next batch if not next_batch_keep_indices: - return results, None + return generations, None next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0) # If we finished at least one generation, we need to evict the indices of the generations that finished @@ -470,4 +470,4 @@ class CausalLM(Model): max_sequence_length=next_batch_max_sequence_length, keys_head_dim_last=batch.keys_head_dim_last, ) - return results, next_batch + return generations, next_batch diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index f965ea88..ddb49078 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz from typing import Optional, Tuple, List, Type from text_generation.models import Model -from text_generation.models.types import GeneratedText, Batch +from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens from text_generation.pb import generate_pb2 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] - decoder_logprobs: 
List[Optional[torch.Tensor]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -51,10 +50,10 @@ class Seq2SeqLMBatch(Batch): @classmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "Seq2SeqLMBatch": """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch""" inputs = [] @@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch): decoder_input_ids = [] decoder_input_lengths = [] - decoder_logprobs = [] # Parse batch for r in pb.requests: @@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch): stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) - decoder_logprobs.append(None) # Tokenize batch pad_to_multiple_of = 8 if device.type == "cuda" else None @@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch): past_key_values=None, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, - decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=len(pb.requests), @@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch): requests = [] input_lengths = [] decoder_input_lengths = [] - decoder_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) - decoder_logprobs.extend(batch.decoder_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -164,8 +158,8 @@ class Seq2SeqLMBatch(Batch): ) # Copy to correct indices input_ids[ - start_index:end_index, -batch.max_input_length : - ] = batch.input_ids[:, -batch.max_input_length :] + start_index:end_index, -batch.max_input_length: + ] = batch.input_ids[:, -batch.max_input_length:] # Create padded tensor if attention_mask is None: @@ -174,8 +168,8 @@ class Seq2SeqLMBatch(Batch): ) # Copy to correct indices attention_mask[ - start_index:end_index, -batch.max_input_length : - ] = batch.attention_mask[:, -batch.max_input_length :] + start_index:end_index, -batch.max_input_length: + ] = batch.attention_mask[:, -batch.max_input_length:] # Create padded tensor if decoder_input_ids is None: @@ -184,8 +178,8 @@ class Seq2SeqLMBatch(Batch): ) # Copy to correct indices decoder_input_ids[ - start_index:end_index, -batch.max_decoder_input_length : - ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :] + start_index:end_index, -batch.max_decoder_input_length: + ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:] # Create padded tensor if decoder_attention_mask is None: @@ -197,13 +191,13 @@ class Seq2SeqLMBatch(Batch): # this batch. All generations are of length `batch.max_decoder_input_length`. 
if batch.decoder_attention_mask is None: decoder_attention_mask[ - start_index:end_index, -batch.max_decoder_input_length : + start_index:end_index, -batch.max_decoder_input_length: ] = 1 # If it exists, we need to index else: decoder_attention_mask[ - start_index:end_index, -batch.max_decoder_input_length : - ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :] + start_index:end_index, -batch.max_decoder_input_length: + ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:] # Create padded tensor if encoder_last_hidden_state is None: @@ -217,8 +211,8 @@ class Seq2SeqLMBatch(Batch): # Copy to correct indices encoder_last_hidden_state[ - start_index:end_index, -batch.max_input_length :, : - ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :] + start_index:end_index, -batch.max_input_length:, : + ] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :] # Iterate over attention layers for j, past in enumerate(batch.past_key_values): @@ -244,11 +238,11 @@ class Seq2SeqLMBatch(Batch): # We slice the past keys and values to remove the padding from previous batches past_key_values[j][k][ - start_index:end_index, - :, - -(batch.max_decoder_input_length - 1) :, - :, - ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :] + start_index:end_index, + :, + -(batch.max_decoder_input_length - 1):, + :, + ] = t[:, :, -(batch.max_decoder_input_length - 1):, :] # encoder past for k, t in enumerate(past[2:]): @@ -267,8 +261,8 @@ class Seq2SeqLMBatch(Batch): past_key_values[j].append(t.new_zeros(padded_t_shape)) past_key_values[j][idx][ - start_index:end_index, :, -batch.max_input_length :, : - ] = t[:, :, -batch.max_input_length :, :] + start_index:end_index, :, -batch.max_input_length:, : + ] = t[:, :, -batch.max_input_length:, :] start_index += batch.size @@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch): past_key_values=past_key_values, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, - decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=total_batch_size, @@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length=max_decoder_input_length, ) + def __len__(self): + return len(self.requests) + class Seq2SeqLM(Model): def __init__(self, model_name: str, quantize=False): @@ -326,13 +322,13 @@ class Seq2SeqLM(Model): return self.tokenizer.decode(decoder_ids, skip_special_tokens=True) def forward( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask: Optional, - encoder_last_hidden_state: Optional, - past_key_values: Optional = None, + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask: Optional, + encoder_last_hidden_state: Optional, + past_key_values: Optional = None, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -363,8 +359,8 @@ class Seq2SeqLM(Model): ) def generate_token( - self, batch: Seq2SeqLMBatch - ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]: + self, batch: Seq2SeqLMBatch + ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: # For some reason, inference_mode does not work well with GLOO which we use on CPU context_manager = ( torch.no_grad if self.device.type == "cpu" else torch.inference_mode @@ -386,7 +382,6 @@ class Seq2SeqLM(Model): next_batch_input_lengths = [] next_batch_decoder_input_ids = [] next_batch_decoder_input_lengths = [] - next_batch_decoder_logprobs = [] # Metadata next_batch_size = 0 @@ -394,14 +389,13 @@ class Seq2SeqLM(Model): 
next_batch_max_decoder_input_length = 0 # Finished requests - generated_texts: List[GeneratedText] = [] + generations: List[Generation] = [] # Zipped iterator iterator = zip( batch.requests, batch.input_lengths, batch.decoder_input_lengths, - batch.decoder_logprobs, logits, batch.next_token_choosers, batch.stopping_criterias, @@ -411,46 +405,39 @@ class Seq2SeqLM(Model): # For each member of the batch for i, ( - request, - input_length, - decoder_input_length, - decoder_logprobs, - logits, - next_token_chooser, - stopping_criteria, - input_tokens, - decoder_input_ids, + request, + input_length, + decoder_input_length, + logits, + next_token_chooser, + stopping_criteria, + input_tokens, + decoder_input_ids, ) in enumerate(iterator): # Select next token - next_token, logprobs = next_token_chooser(decoder_input_ids, logits) + next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits) # Append next token to decoder tokens - decoder_input_ids = torch.cat([decoder_input_ids, next_token]) + decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) new_decoder_input_length = decoder_input_length + 1 - next_token_logprob = logprobs[-1, next_token] - if decoder_logprobs is None: - decoder_logprobs = next_token_logprob - else: - decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob]) + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text = self.tokenizer.decode(next_token_id_squeezed, + clean_up_tokenization_spaces=False, + skip_special_tokens=False) # Evaluate stopping criteria stop, reason = stopping_criteria( - next_token.squeeze(), - self.tokenizer.decode( - next_token.squeeze(), clean_up_tokenization_spaces=False - ), + next_token_id, + next_token_text ) + if stop: # Slice with decoder_input_length to remove padding # Decode all tokens - token_ids = decoder_input_ids[-new_decoder_input_length:] - output_text = self.decode(token_ids) - tokens = self.tokenizer.batch_decode(token_ids) - # Add NaN for the bos token - logprobs = [float("nan")] + decoder_logprobs[ - -decoder_input_length: - ].tolist() + output_text = self.decode(decoder_input_ids[-new_decoder_input_length:]) # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -458,27 +445,17 @@ class Seq2SeqLM(Model): else: seed = None - # Add to the list of finished generations with the original request - generated_texts.append( - GeneratedText( - request=request, - output_text=output_text, - generated_tokens=stopping_criteria.current_tokens, - tokens=tokens, - token_ids=token_ids.tolist(), - logprobs=logprobs, - reason=reason, - seed=seed, - ) + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed ) - # add to the next batch else: + # Keep request in the batch + generated_text = None next_batch_keep_indices.append(i) next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) next_batch_size += 1 next_batch_input_lengths.append(input_length) next_batch_decoder_input_lengths.append(new_decoder_input_length) - next_batch_decoder_logprobs.append(decoder_logprobs) next_batch_max_input_length = max( next_batch_max_input_length, input_length ) @@ -486,14 +463,39 @@ class Seq2SeqLM(Model): next_batch_max_decoder_input_length, new_decoder_input_length ) + # Prefill + if stopping_criteria.current_tokens == 1: + prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + 
clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, [float("nan")], prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + generated_text, + ) + + generations.append(generation) + # We finished all generations in the batch; there is no next batch if not next_batch_keep_indices: - return generated_texts, None + return generations, None next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids) # If we finished at least one generation, we need to evict the indices of the generations that finished # from the values of the next batch - if generated_texts: + if len(next_batch_keep_indices) != len(batch): # Apply indices to attention mask, past key values and other items that need to be cached next_batch_input_ids = batch.input_ids[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] @@ -551,11 +553,10 @@ class Seq2SeqLM(Model): past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, decoder_input_lengths=next_batch_decoder_input_lengths, - decoder_logprobs=next_batch_decoder_logprobs, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size, max_input_length=next_batch_max_input_length, max_decoder_input_length=next_batch_max_decoder_input_length, ) - return generated_texts, next_batch + return generations, next_batch diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index de21b20b..2407da4d 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -29,6 +29,10 @@ class Batch(ABC): def concatenate(cls, batches: List["Batch"]) -> "Batch": raise NotImplementedError + @abstractmethod + def __len__(self): + raise NotImplementedError + @dataclass class GeneratedText:
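
Note (not part of the diff): the reworked prefill-logprob line in causal_lm.py packs a gather and two slices into one expression. Below is a minimal stand-alone sketch of what it computes, with toy shapes assumed from how the patch uses the tensors: `logprobs` as [input_length, vocab_size] from the prefill forward pass, and `all_input_ids` as a [new_input_length, 1] column holding the prompt ids plus the freshly generated token.

import torch

torch.manual_seed(0)
input_length = 4                      # number of prompt tokens
new_input_length = input_length + 1   # prompt plus one generated token
vocab_size = 10

# One row of log-probabilities per prompt position, as produced by the prefill pass
logprobs = torch.log_softmax(torch.randn(input_length, vocab_size), dim=-1)
# Prompt ids plus the sampled token, kept as a column vector as in the patch
all_input_ids = torch.randint(vocab_size, (new_input_length, 1))

# Row i of `logprobs` scores the token at position i + 1, so drop the first id
# before gathering: the result is the model's log-probability of each prompt token
# given the tokens before it.
token_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)

# NaN stands in for the first prompt token (nothing precedes it), and the last entry
# is dropped because it belongs to the generated token, which the patch reports
# separately as next_token_logprob.
prefill_logprobs = [float("nan")] + token_logprobs[-new_input_length:-1].tolist()
assert len(prefill_logprobs) == input_length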
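
Note (not part of the diff): the `Generation` and `PrefillTokens` classes imported in seq2seq_lm.py, and the reshaped positional constructor of `GeneratedText`, are not visible in the truncated types.py hunk above. The following is only a rough reconstruction inferred from the call sites in this patch; the field names are guesses, and any protobuf conversion the real classes carry is omitted.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class PrefillTokens:
    # Built as PrefillTokens(prefill_token_ids, prefill_logprobs, prefill_texts)
    token_ids: torch.Tensor
    logprobs: List[float]
    texts: List[str]


@dataclass
class GeneratedText:
    # Built as GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
    text: str
    generated_tokens: int
    finish_reason: str
    seed: Optional[int]


@dataclass
class Generation:
    # Built as Generation(request.id, prefill_tokens, next_token_id_squeezed,
    #                     next_token_logprob, next_token_text, generated_text)
    request_id: int
    prefill_tokens: Optional[PrefillTokens]
    token_id: torch.Tensor        # 0-d tensor in the patch; could equally be an int
    token_logprob: torch.Tensor   # likewise, could be a plain float
    token_text: str
    generated_text: Optional[GeneratedText]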