working integration tests

OlivierDehaene 2023-01-30 11:37:36 +01:00
parent 046801278e
commit b2a468176d
3 changed files with 102 additions and 97 deletions

View File

@@ -311,7 +311,7 @@ class CausalLM(Model):
         next_batch_max_sequence_length = 0

         # Results
-        results = []
+        generations: List[Generation] = []

         # Zipped iterator
         iterator = zip(
@@ -343,7 +343,9 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode(next_token_id.squeeze())
+            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False)

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -381,11 +383,9 @@ class CausalLM(Model):
                 )

             # Prefill
-            if stopping_criteria.current_tokens == 0:
+            if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[
-                    -new_input_length:-1
-                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
@@ -398,7 +398,7 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None

-            result = Generation(
+            generation = Generation(
                 request.id,
                 prefill_tokens,
                 next_token_id_squeezed,
@@ -407,11 +407,11 @@ class CausalLM(Model):
                 generated_text,
             )

-            results.append(result)
+            generations.append(generation)

         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
-            return results, None
+            return generations, None

         next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
         # If we finished at least one generation, we need to evict the indices of the generations that finished
@@ -470,4 +470,4 @@ class CausalLM(Model):
             max_sequence_length=next_batch_max_sequence_length,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
-        return results, next_batch
+        return generations, next_batch
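Note on the prefill-logprob change above: the rewritten one-liner gathers, for each prompt position, the log-probability of the token that actually followed, then slices away the token generated this step and pads the first prompt token (which has no prediction) with NaN. The toy sketch below uses invented shapes and values, not code from this repository, to show what the expression computes.

    import torch

    # Hypothetical shapes: a 3-token prompt plus 1 freshly generated token,
    # stored as a column vector like all_input_ids in the model code.
    all_input_ids = torch.tensor([[2], [5], [7], [1]])
    new_input_length = 4                    # prompt length 3 + the new token
    vocab_size = 10
    # One row of log-probabilities per predicted position (positions 1..3).
    logprobs = torch.log_softmax(torch.randn(3, vocab_size), dim=-1)

    # Log-probability of each token that was actually observed, given its prefix.
    token_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)

    # Keep only the prompt span, drop the generated token, and pad the first
    # prompt token with NaN -- mirroring the expression introduced by the diff.
    prefill_logprobs = [float("nan")] + token_logprobs[-new_input_length:-1].tolist()
    print(len(prefill_logprobs))  # 3 == number of prompt tokens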

View File

@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz
 from typing import Optional, Tuple, List, Type

 from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
-    decoder_logprobs: List[Optional[torch.Tensor]]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -51,10 +50,10 @@ class Seq2SeqLMBatch(Batch):
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = []
         decoder_input_lengths = []
-        decoder_logprobs = []

         # Parse batch
         for r in pb.requests:
@@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch):
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
-            decoder_logprobs.append(None)

         # Tokenize batch
         pad_to_multiple_of = 8 if device.type == "cuda" else None
@@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
-            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch):
         requests = []
         input_lengths = []
         decoder_input_lengths = []
-        decoder_logprobs = []
         next_token_choosers = []
         stopping_criterias = []
@@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
-            decoder_logprobs.extend(batch.decoder_logprobs)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -164,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             input_ids[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.input_ids[:, -batch.max_input_length :]
+                start_index:end_index, -batch.max_input_length:
+            ] = batch.input_ids[:, -batch.max_input_length:]

             # Create padded tensor
             if attention_mask is None:
@@ -174,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             attention_mask[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.attention_mask[:, -batch.max_input_length :]
+                start_index:end_index, -batch.max_input_length:
+            ] = batch.attention_mask[:, -batch.max_input_length:]

             # Create padded tensor
             if decoder_input_ids is None:
@@ -184,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             decoder_input_ids[
-                start_index:end_index, -batch.max_decoder_input_length :
-            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
+                start_index:end_index, -batch.max_decoder_input_length:
+            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]

             # Create padded tensor
             if decoder_attention_mask is None:
@@ -197,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
             # this batch. All generations are of length `batch.max_decoder_input_length`.
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length :
+                    start_index:end_index, -batch.max_decoder_input_length:
                 ] = 1
             # If it exists, we need to index
             else:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length :
-                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
+                    start_index:end_index, -batch.max_decoder_input_length:
+                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]

             # Create padded tensor
             if encoder_last_hidden_state is None:
@@ -217,8 +211,8 @@ class Seq2SeqLMBatch(Batch):
             # Copy to correct indices
             encoder_last_hidden_state[
-                start_index:end_index, -batch.max_input_length :, :
-            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
+                start_index:end_index, -batch.max_input_length:, :
+            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]

             # Iterate over attention layers
             for j, past in enumerate(batch.past_key_values):
@@ -244,11 +238,11 @@ class Seq2SeqLMBatch(Batch):
                     # We slice the past keys and values to remove the padding from previous batches
                     past_key_values[j][k][
                         start_index:end_index,
                         :,
-                        -(batch.max_decoder_input_length - 1) :,
+                        -(batch.max_decoder_input_length - 1):,
                         :,
-                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
+                    ] = t[:, :, -(batch.max_decoder_input_length - 1):, :]

                 # encoder past
                 for k, t in enumerate(past[2:]):
@@ -267,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
                         past_key_values[j].append(t.new_zeros(padded_t_shape))

                     past_key_values[j][idx][
-                        start_index:end_index, :, -batch.max_input_length :, :
-                    ] = t[:, :, -batch.max_input_length :, :]
+                        start_index:end_index, :, -batch.max_input_length:, :
+                    ] = t[:, :, -batch.max_input_length:, :]

             start_index += batch.size
@@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
-            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length=max_decoder_input_length,
         )

+    def __len__(self):
+        return len(self.requests)
+

 class Seq2SeqLM(Model):
     def __init__(self, model_name: str, quantize=False):
@@ -326,13 +322,13 @@ class Seq2SeqLM(Model):
         return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

     def forward(
         self,
         input_ids,
         attention_mask,
         decoder_input_ids,
         decoder_attention_mask: Optional,
         encoder_last_hidden_state: Optional,
         past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
@@ -363,8 +359,8 @@ class Seq2SeqLM(Model):
         )

     def generate_token(
         self, batch: Seq2SeqLMBatch
-    ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
             torch.no_grad if self.device.type == "cpu" else torch.inference_mode
@@ -386,7 +382,6 @@ class Seq2SeqLM(Model):
         next_batch_input_lengths = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
-        next_batch_decoder_logprobs = []

         # Metadata
         next_batch_size = 0
@@ -394,14 +389,13 @@ class Seq2SeqLM(Model):
         next_batch_max_decoder_input_length = 0

         # Finished requests
-        generated_texts: List[GeneratedText] = []
+        generations: List[Generation] = []

         # Zipped iterator
         iterator = zip(
             batch.requests,
             batch.input_lengths,
             batch.decoder_input_lengths,
-            batch.decoder_logprobs,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -411,46 +405,39 @@ class Seq2SeqLM(Model):
         # For each member of the batch
         for i, (
             request,
             input_length,
             decoder_input_length,
-            decoder_logprobs,
             logits,
             next_token_chooser,
             stopping_criteria,
             input_tokens,
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            next_token, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)

             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
             new_decoder_input_length = decoder_input_length + 1

-            next_token_logprob = logprobs[-1, next_token]
-            if decoder_logprobs is None:
-                decoder_logprobs = next_token_logprob
-            else:
-                decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False)

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token.squeeze(),
-                self.tokenizer.decode(
-                    next_token.squeeze(), clean_up_tokenization_spaces=False
-                ),
+                next_token_id,
+                next_token_text
             )

             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                token_ids = decoder_input_ids[-new_decoder_input_length:]
-                output_text = self.decode(token_ids)
-                tokens = self.tokenizer.batch_decode(token_ids)
-                # Add NaN for the bos token
-                logprobs = [float("nan")] + decoder_logprobs[
-                    -decoder_input_length:
-                ].tolist()
+                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -458,27 +445,17 @@ class Seq2SeqLM(Model):
                 else:
                     seed = None

-                # Add to the list of finished generations with the original request
-                generated_texts.append(
-                    GeneratedText(
-                        request=request,
-                        output_text=output_text,
-                        generated_tokens=stopping_criteria.current_tokens,
-                        tokens=tokens,
-                        token_ids=token_ids.tolist(),
-                        logprobs=logprobs,
-                        reason=reason,
-                        seed=seed,
-                    )
-                )
-            # add to the next batch
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
             else:
+                # Keep request in the batch
+                generated_text = None
                 next_batch_keep_indices.append(i)
                 next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                 next_batch_size += 1
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                next_batch_decoder_logprobs.append(decoder_logprobs)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
@@ -486,14 +463,39 @@ class Seq2SeqLM(Model):
                     next_batch_max_decoder_input_length, new_decoder_input_length
                 )

+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, [float("nan")], prefill_texts
+                )
+            else:
+                prefill_tokens = None
+
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)
+
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
-            return generated_texts, None
+            return generations, None

         next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)

         # If we finished at least one generation, we need to evict the indices of the generations that finished
         # from the values of the next batch
-        if generated_texts:
+        if len(next_batch_keep_indices) != len(batch):
             # Apply indices to attention mask, past key values and other items that need to be cached
             next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
@@ -551,11 +553,10 @@ class Seq2SeqLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
-            decoder_logprobs=next_batch_decoder_logprobs,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_input_length=next_batch_max_input_length,
             max_decoder_input_length=next_batch_max_decoder_input_length,
         )
-        return generated_texts, next_batch
+        return generations, next_batch
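The eviction check in this file also changes: `if generated_texts:` becomes `if len(next_batch_keep_indices) != len(batch):`, relying on the new abstract `__len__` added to `Batch` below. Since a `Generation` is now appended for every request on every step, not only for finished ones, counting how many requests survive the step is the reliable signal that cached tensors need slicing. A small illustrative sketch, values invented:

    # Illustrative only: indices of requests that keep generating after this
    # step, versus the number of requests in the current batch (len(batch)
    # via the new Batch.__len__).
    next_batch_keep_indices = [0, 2]   # request 1 finished this step
    batch_size = 3

    needs_filtering = len(next_batch_keep_indices) != batch_size
    print(needs_filtering)  # True -> slice attention mask, past key values, etc.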

View File

@@ -29,6 +29,10 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError

+    @abstractmethod
+    def __len__(self):
+        raise NotImplementedError
+

 @dataclass
 class GeneratedText:
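For reference, the diff constructs `Generation`, `PrefillTokens`, and the slimmed-down `GeneratedText` positionally, so their exact field names are not visible here. Below is a plausible sketch of what the dataclasses in the types module might look like, with field names inferred from the call sites and therefore only assumptions:

    from dataclasses import dataclass
    from typing import List, Optional

    import torch


    @dataclass
    class PrefillTokens:
        # Inferred from PrefillTokens(prefill_token_ids, [float("nan")], prefill_texts)
        token_ids: torch.Tensor
        logprobs: List[float]
        texts: List[str]


    @dataclass
    class GeneratedText:
        # Inferred from GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
        text: str
        generated_tokens: int
        finish_reason: str
        seed: Optional[int]


    @dataclass
    class Generation:
        # One of these is emitted per request on every generate_token step.
        request_id: int
        prefill_tokens: Optional[PrefillTokens]
        token_id: torch.Tensor
        token_logprob: float
        token_text: str
        generated_text: Optional[GeneratedText]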