working integration tests

OlivierDehaene 2023-01-30 11:37:36 +01:00
parent 429155a26a
commit 4a538cfa49
3 changed files with 102 additions and 97 deletions

View File

@ -311,7 +311,7 @@ class CausalLM(Model):
next_batch_max_sequence_length = 0
# Results
results = []
generations: List[Generation] = []
# Zipped iterator
iterator = zip(
@ -343,7 +343,9 @@ class CausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.decode(next_token_id.squeeze())
next_token_text = self.tokenizer.decode(next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False)
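A minimal standalone sketch of the per-token decode above, assuming a stock transformers tokenizer such as "gpt2" (not part of this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_id = tokenizer.encode(" world")[0]

# decode a single id without normalizing spaces or hiding special tokens,
# so the streamed pieces concatenate back to the exact generated text
next_token_text = tokenizer.decode(
    token_id,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)
print(repr(next_token_text))  # typically ' world' for the gpt2 vocab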
# Evaluate stopping criteria
stop, reason = stopping_criteria(
@ -381,11 +383,9 @@ class CausalLM(Model):
)
# Prefill
if stopping_criteria.current_tokens == 0:
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs[
-new_input_length:-1
].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
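A toy sketch of the shifted gather used for the prefill logprobs above; the shapes and tensors here are made up for illustration:

import torch

vocab_size, prompt_len = 8, 4
# Prefill logits cover the prompt positions; one generated token has already been
# appended to all_input_ids, so it holds prompt_len + 1 ids of shape [.., 1].
logprobs = torch.log_softmax(torch.randn(prompt_len, vocab_size), dim=-1)
all_input_ids = torch.randint(vocab_size, (prompt_len + 1, 1))
new_input_length = prompt_len + 1

# position i scores the token that actually appears at i + 1, hence the gather
# over all_input_ids[1:]; the trailing entry (the freshly generated token) is
# sliced away and nan stands in for the first prompt token, which has no score
prefill_logprobs = [float("nan")] + logprobs.gather(
    1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()

assert len(prefill_logprobs) == prompt_len  # one logprob per prompt token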
@ -398,7 +398,7 @@ class CausalLM(Model):
else:
prefill_tokens = None
result = Generation(
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
@ -407,11 +407,11 @@ class CausalLM(Model):
generated_text,
)
results.append(result)
generations.append(generation)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return results, None
return generations, None
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished
@ -470,4 +470,4 @@ class CausalLM(Model):
max_sequence_length=next_batch_max_sequence_length,
keys_head_dim_last=batch.keys_head_dim_last,
)
return results, next_batch
return generations, next_batch
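For orientation, a rough sketch of what the per-step payloads used above might look like; the field names below are inferred from the positional constructor calls in this diff and are not the repository's actual definitions in text_generation.models.types:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PrefillTokens:
    token_ids: List[int]      # prompt token ids
    logprobs: List[float]     # nan for the first prompt token
    texts: List[str]          # decoded prompt tokens

@dataclass
class GeneratedText:
    text: str                 # full decoded output
    generated_tokens: int
    finish_reason: str
    seed: Optional[int]

@dataclass
class Generation:
    request_id: int
    prefill_tokens: Optional[PrefillTokens]   # only set on the prefill step
    token_id: int
    token_logprob: float
    token_text: str
    generated_text: Optional[GeneratedText]   # only set when the request stops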

View File

@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
decoder_logprobs: List[Optional[torch.Tensor]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@ -51,10 +50,10 @@ class Seq2SeqLMBatch(Batch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = []
decoder_input_lengths = []
decoder_logprobs = []
# Parse batch
for r in pb.requests:
@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
decoder_logprobs.append(None)
# Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch):
past_key_values=None,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch):
requests = []
input_lengths = []
decoder_input_lengths = []
decoder_logprobs = []
next_token_choosers = []
stopping_criterias = []
@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
decoder_logprobs.extend(batch.decoder_logprobs)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@ -164,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
input_ids[
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
start_index:end_index, -batch.max_input_length:
] = batch.input_ids[:, -batch.max_input_length:]
# Create padded tensor
if attention_mask is None:
@ -174,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
attention_mask[
start_index:end_index, -batch.max_input_length :
] = batch.attention_mask[:, -batch.max_input_length :]
start_index:end_index, -batch.max_input_length:
] = batch.attention_mask[:, -batch.max_input_length:]
# Create padded tensor
if decoder_input_ids is None:
@ -184,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
decoder_input_ids[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
start_index:end_index, -batch.max_decoder_input_length:
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]
# Create padded tensor
if decoder_attention_mask is None:
@ -197,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
# this batch. All generations are of length `batch.max_decoder_input_length`.
if batch.decoder_attention_mask is None:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
start_index:end_index, -batch.max_decoder_input_length:
] = 1
# If it exists, we need to index
else:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
start_index:end_index, -batch.max_decoder_input_length:
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]
# Create padded tensor
if encoder_last_hidden_state is None:
@ -217,8 +211,8 @@ class Seq2SeqLMBatch(Batch):
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
start_index:end_index, -batch.max_input_length:, :
] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]
# Iterate over attention layers
for j, past in enumerate(batch.past_key_values):
@ -244,11 +238,11 @@ class Seq2SeqLMBatch(Batch):
# We slice the past keys and values to remove the padding from previous batches
past_key_values[j][k][
start_index:end_index,
:,
-(batch.max_decoder_input_length - 1) :,
:,
] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
start_index:end_index,
:,
-(batch.max_decoder_input_length - 1):,
:,
] = t[:, :, -(batch.max_decoder_input_length - 1):, :]
# encoder past
for k, t in enumerate(past[2:]):
@ -267,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
past_key_values[j].append(t.new_zeros(padded_t_shape))
past_key_values[j][idx][
start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :]
start_index:end_index, :, -batch.max_input_length:, :
] = t[:, :, -batch.max_input_length:, :]
start_index += batch.size
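As an aside, a toy illustration of the right-aligned copies this concatenation performs, with made-up shapes:

import torch

total_batch_size, max_input_length = 3, 6
input_ids = torch.zeros(total_batch_size, max_input_length, dtype=torch.long)

batch_input_ids = torch.arange(1, 9).reshape(2, 4)   # a smaller batch, length 4
start_index, end_index = 0, 2
batch_max_input_length = batch_input_ids.shape[1]

# copy right-aligned so the left padding stays on the left for every request
input_ids[
    start_index:end_index, -batch_max_input_length:
] = batch_input_ids[:, -batch_max_input_length:]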
@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch):
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length=max_decoder_input_length,
)
def __len__(self):
return len(self.requests)
class Seq2SeqLM(Model):
def __init__(self, model_name: str, quantize=False):
@ -326,13 +322,13 @@ class Seq2SeqLM(Model):
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
def forward(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask: Optional,
encoder_last_hidden_state: Optional,
past_key_values: Optional = None,
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask: Optional,
encoder_last_hidden_state: Optional,
past_key_values: Optional = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
@ -363,8 +359,8 @@ class Seq2SeqLM(Model):
)
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
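A self-contained sketch of the device-dependent selection above; both torch.no_grad and torch.inference_mode are callables that return context managers:

import torch

device = torch.device("cpu")
context_manager = torch.no_grad if device.type == "cpu" else torch.inference_mode

with context_manager():
    out = torch.randn(2, 3) @ torch.randn(3, 4)  # stands in for the model forward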
@ -386,7 +382,6 @@ class Seq2SeqLM(Model):
next_batch_input_lengths = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
next_batch_decoder_logprobs = []
# Metadata
next_batch_size = 0
@ -394,14 +389,13 @@ class Seq2SeqLM(Model):
next_batch_max_decoder_input_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
generations: List[Generation] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.decoder_input_lengths,
batch.decoder_logprobs,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
@ -411,46 +405,39 @@ class Seq2SeqLM(Model):
# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
decoder_logprobs,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token, logprobs = next_token_chooser(decoder_input_ids, logits)
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token])
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1
next_token_logprob = logprobs[-1, next_token]
if decoder_logprobs is None:
decoder_logprobs = next_token_logprob
else:
decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token.squeeze(),
self.tokenizer.decode(
next_token.squeeze(), clean_up_tokenization_spaces=False
),
next_token_id,
next_token_text
)
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
token_ids = decoder_input_ids[-new_decoder_input_length:]
output_text = self.decode(token_ids)
tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the bos token
logprobs = [float("nan")] + decoder_logprobs[
-decoder_input_length:
].tolist()
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@ -458,27 +445,17 @@ class Seq2SeqLM(Model):
else:
seed = None
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(
request=request,
output_text=output_text,
generated_tokens=stopping_criteria.current_tokens,
tokens=tokens,
token_ids=token_ids.tolist(),
logprobs=logprobs,
reason=reason,
seed=seed,
)
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
# add to the next batch
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_decoder_logprobs.append(decoder_logprobs)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
@ -486,14 +463,39 @@ class Seq2SeqLM(Model):
next_batch_max_decoder_input_length, new_decoder_input_length
)
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
generations.append(generation)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
return generations, None
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if generated_texts:
if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
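A small sketch of the eviction step above, keeping only the rows of still-running requests by indexing with the surviving positions (shapes are illustrative):

import torch

attention_mask = torch.ones(3, 6)     # cached mask for a batch of 3 requests
next_batch_keep_indices = [0, 2]      # the request at index 1 just finished
next_batch_attention_mask = attention_mask[next_batch_keep_indices]
assert next_batch_attention_mask.shape == (2, 6)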
@ -551,11 +553,10 @@ class Seq2SeqLM(Model):
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths,
decoder_logprobs=next_batch_decoder_logprobs,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length,
)
return generated_texts, next_batch
return generations, next_batch

View File

@ -29,6 +29,10 @@ class Batch(ABC):
def concatenate(cls, batches: List["Batch"]) -> "Batch":
raise NotImplementedError
@abstractmethod
def __len__(self):
raise NotImplementedError
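For context, a minimal sketch of a concrete batch satisfying this abstract __len__ by delegating to its request list, mirroring the Seq2SeqLMBatch change above; ToyBatch is a made-up stand-in:

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List

class Batch(ABC):
    @abstractmethod
    def __len__(self):
        raise NotImplementedError

@dataclass
class ToyBatch(Batch):
    requests: List[Any] = field(default_factory=list)

    def __len__(self):
        return len(self.requests)

# enables checks such as len(next_batch_keep_indices) != len(batch)
assert len(ToyBatch(requests=["a", "b"])) == 2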
@dataclass
class GeneratedText: