Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)

Commit: b2a468176d ("working integration tests")
Parent: 046801278e
@@ -311,7 +311,7 @@ class CausalLM(Model):
next_batch_max_sequence_length = 0

# Results
-results = []
+generations: List[Generation] = []

# Zipped iterator
iterator = zip(
@@ -343,7 +343,9 @@ class CausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
-next_token_text = self.decode(next_token_id.squeeze())
+next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+                                        clean_up_tokenization_spaces=False,
+                                        skip_special_tokens=False)

# Evaluate stopping criteria
stop, reason = stopping_criteria(
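A side note on the decode call introduced above: decoding one token id at a time with skip_special_tokens=False and clean_up_tokenization_spaces=False keeps the raw per-token text, so concatenating the pieces reproduces the full decoded string. A minimal sketch, not part of this commit (the model name and input are placeholders chosen only for illustration):

    from transformers import AutoTokenizer

    # Placeholder tokenizer, used only to illustrate the call pattern.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    token_id = tokenizer.encode(" world")[0]  # a single token id

    # Raw per-token text: special tokens and spacing are preserved, which matters
    # when token texts are emitted incrementally instead of re-decoding the sequence.
    raw = tokenizer.decode(
        token_id, clean_up_tokenization_spaces=False, skip_special_tokens=False
    )
    print(repr(raw))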
@@ -381,11 +383,9 @@ class CausalLM(Model):
)

# Prefill
-if stopping_criteria.current_tokens == 0:
+if stopping_criteria.current_tokens == 1:
    # Remove generated token to only have prefill and add nan for first prompt token
-    prefill_logprobs = [float("nan")] + logprobs[
-        -new_input_length:-1
-    ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
+    prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
    prefill_token_ids = all_input_ids[-new_input_length:-1]
    prefill_texts = self.tokenizer.batch_decode(
        prefill_token_ids,
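The rewritten prefill line replaces the sliced gather with one gather over the whole sequence before slicing. A small self-contained sketch of the same indexing idea, with shapes and values invented for illustration (it is not the repository's code):

    import torch

    torch.manual_seed(0)
    vocab_size, seq_len = 10, 5
    logprobs = torch.log_softmax(torch.randn(seq_len, vocab_size), dim=-1)  # [seq_len, vocab]
    all_input_ids = torch.randint(vocab_size, (seq_len, 1))                 # prompt ids, column vector

    # Row i of `logprobs` is the distribution predicted after token i, so the logprob
    # of prompt token i+1 sits at logprobs[i, all_input_ids[i + 1]]; one gather covers all rows.
    token_logprobs = logprobs.gather(1, all_input_ids[1:]).squeeze(1)

    # The first prompt token has no prefix, hence the leading NaN in the diff; the real
    # code additionally slices with [-new_input_length:-1] to drop the generated token.
    prefill_logprobs = [float("nan")] + token_logprobs.tolist()
    print(prefill_logprobs)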
@@ -398,7 +398,7 @@ class CausalLM(Model):
else:
    prefill_tokens = None

-result = Generation(
+generation = Generation(
    request.id,
    prefill_tokens,
    next_token_id_squeezed,
@@ -407,11 +407,11 @@ class CausalLM(Model):
    generated_text,
)

-results.append(result)
+generations.append(generation)

# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
-    return results, None
+    return generations, None

next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished
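Generation and PrefillTokens come from text_generation.models.types, which this excerpt only touches at the end. A rough sketch of what the positional constructor calls above imply; the field names are guesses for illustration, not the actual definitions:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class PrefillTokens:
        # Guessed fields: ids, logprobs (NaN for the first prompt token), and texts.
        token_ids: List[int]
        logprobs: List[float]
        texts: List[str]

    @dataclass
    class Generation:
        request_id: int
        prefill_tokens: Optional[PrefillTokens]    # only set on the first decoding step
        token_id: int
        token_logprob: float
        token_text: str
        generated_text: Optional["GeneratedText"]  # only set once the request finishes

    gen = Generation(0, None, 42, -0.1, " hello", None)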
@@ -470,4 +470,4 @@ class CausalLM(Model):
    max_sequence_length=next_batch_max_sequence_length,
    keys_head_dim_last=batch.keys_head_dim_last,
)
-return results, next_batch
+return generations, next_batch
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
-decoder_logprobs: List[Optional[torch.Tensor]]

# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -51,10 +50,10 @@ class Seq2SeqLMBatch(Batch):

@classmethod
def from_pb(
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    device: torch.device,
) -> "Seq2SeqLMBatch":
    """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
    inputs = []
@@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch):

decoder_input_ids = []
decoder_input_lengths = []
-decoder_logprobs = []

# Parse batch
for r in pb.requests:
@@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias.append(
    StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
-decoder_logprobs.append(None)

# Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
@@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch):
past_key_values=None,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
-decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
@@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch):
requests = []
input_lengths = []
decoder_input_lengths = []
-decoder_logprobs = []
next_token_choosers = []
stopping_criterias = []

@@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
-decoder_logprobs.extend(batch.decoder_logprobs)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@@ -164,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
input_ids[
-    start_index:end_index, -batch.max_input_length :
-] = batch.input_ids[:, -batch.max_input_length :]
+    start_index:end_index, -batch.max_input_length:
+] = batch.input_ids[:, -batch.max_input_length:]

# Create padded tensor
if attention_mask is None:
@@ -174,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
attention_mask[
-    start_index:end_index, -batch.max_input_length :
-] = batch.attention_mask[:, -batch.max_input_length :]
+    start_index:end_index, -batch.max_input_length:
+] = batch.attention_mask[:, -batch.max_input_length:]

# Create padded tensor
if decoder_input_ids is None:
@@ -184,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
decoder_input_ids[
-    start_index:end_index, -batch.max_decoder_input_length :
-] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
+    start_index:end_index, -batch.max_decoder_input_length:
+] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]

# Create padded tensor
if decoder_attention_mask is None:
@@ -197,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
# this batch. All generations are of length `batch.max_decoder_input_length`.
if batch.decoder_attention_mask is None:
    decoder_attention_mask[
-        start_index:end_index, -batch.max_decoder_input_length :
+        start_index:end_index, -batch.max_decoder_input_length:
    ] = 1
# If it exists, we need to index
else:
    decoder_attention_mask[
-        start_index:end_index, -batch.max_decoder_input_length :
-    ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
+        start_index:end_index, -batch.max_decoder_input_length:
+    ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]

# Create padded tensor
if encoder_last_hidden_state is None:
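These hunks only drop the space before the slice colon, but the pattern they touch is worth spelling out: during concatenation, each batch's left-padded rows are copied into the right-most columns of a larger zero tensor. A toy sketch with invented sizes, not taken from the repository:

    import torch

    # Two left-padded batches of different max lengths, merged into one tensor.
    batch_a = torch.tensor([[0, 11, 12],
                            [13, 14, 15]])          # max_input_length = 3
    batch_b = torch.tensor([[0, 0, 0, 21, 22]])     # max_input_length = 5

    total, max_len = 3, 5
    input_ids = torch.zeros((total, max_len), dtype=torch.long)

    # Same move as `input_ids[start:end, -batch.max_input_length:] = batch.input_ids[...]`.
    input_ids[0:2, -3:] = batch_a[:, -3:]
    input_ids[2:3, -5:] = batch_b[:, -5:]
    print(input_ids)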
@@ -217,8 +211,8 @@ class Seq2SeqLMBatch(Batch):

# Copy to correct indices
encoder_last_hidden_state[
-    start_index:end_index, -batch.max_input_length :, :
-] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
+    start_index:end_index, -batch.max_input_length:, :
+] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]

# Iterate over attention layers
for j, past in enumerate(batch.past_key_values):
@@ -244,11 +238,11 @@ class Seq2SeqLMBatch(Batch):

# We slice the past keys and values to remove the padding from previous batches
past_key_values[j][k][
-    start_index:end_index,
-    :,
-    -(batch.max_decoder_input_length - 1) :,
-    :,
-] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
+    start_index:end_index,
+    :,
+    -(batch.max_decoder_input_length - 1):,
+    :,
+] = t[:, :, -(batch.max_decoder_input_length - 1):, :]

# encoder past
for k, t in enumerate(past[2:]):
@@ -267,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
past_key_values[j].append(t.new_zeros(padded_t_shape))

past_key_values[j][idx][
-    start_index:end_index, :, -batch.max_input_length :, :
-] = t[:, :, -batch.max_input_length :, :]
+    start_index:end_index, :, -batch.max_input_length:, :
+] = t[:, :, -batch.max_input_length:, :]

start_index += batch.size
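The same formatting change also hits the cached attention tensors, where the slice runs over the sequence dimension of [batch, heads, seq_len, head_dim] keys and values. A short sketch of trimming old padding while writing into a padded cache (all sizes invented):

    import torch

    batch_size, heads, head_dim = 2, 4, 8
    old_len, new_max_len = 6, 9

    t = torch.randn(batch_size, heads, old_len, head_dim)               # cached keys of one batch
    padded = t.new_zeros(batch_size + 1, heads, new_max_len, head_dim)  # cache for the merged batch

    # Keep only the last `old_len` positions of the source and write them into the
    # right-most positions of the padded cache, as in the hunks above.
    padded[0:batch_size, :, -old_len:, :] = t[:, :, -old_len:, :]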
@@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch):
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
-decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
@@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch):
    max_decoder_input_length=max_decoder_input_length,
)

+def __len__(self):
+    return len(self.requests)
+

class Seq2SeqLM(Model):
    def __init__(self, model_name: str, quantize=False):
@@ -326,13 +322,13 @@ class Seq2SeqLM(Model):
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

def forward(
    self,
    input_ids,
    attention_mask,
    decoder_input_ids,
    decoder_attention_mask: Optional,
    encoder_last_hidden_state: Optional,
    past_key_values: Optional = None,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
@@ -363,8 +359,8 @@ class Seq2SeqLM(Model):
)

def generate_token(
-    self, batch: Seq2SeqLMBatch
-) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
+    self, batch: Seq2SeqLMBatch
+) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
    # For some reason, inference_mode does not work well with GLOO which we use on CPU
    context_manager = (
        torch.no_grad if self.device.type == "cpu" else torch.inference_mode
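The context lines above keep the existing device-dependent choice of autograd guard: inference_mode is the faster option, but per the comment it does not work well with GLOO-based CPU communication, so no_grad is used on CPU. A minimal sketch of that pattern; the tiny module is only for illustration:

    import torch

    model = torch.nn.Linear(4, 4)
    device = torch.device("cpu")

    # Pick the guard based on the device, mirroring the context lines in the diff.
    context_manager = torch.no_grad if device.type == "cpu" else torch.inference_mode
    with context_manager():
        out = model(torch.randn(1, 4))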
@@ -386,7 +382,6 @@ class Seq2SeqLM(Model):
next_batch_input_lengths = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
-next_batch_decoder_logprobs = []

# Metadata
next_batch_size = 0
@@ -394,14 +389,13 @@ class Seq2SeqLM(Model):
next_batch_max_decoder_input_length = 0

# Finished requests
-generated_texts: List[GeneratedText] = []
+generations: List[Generation] = []

# Zipped iterator
iterator = zip(
    batch.requests,
    batch.input_lengths,
    batch.decoder_input_lengths,
-    batch.decoder_logprobs,
    logits,
    batch.next_token_choosers,
    batch.stopping_criterias,
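Dropping batch.decoder_logprobs from this zip shrinks the per-request tuple unpacked in the loop that follows. For reference, the underlying pattern of walking parallel per-request lists together, with invented names and values:

    # Parallel per-request lists walked together; the diff removes one of them.
    requests = ["req-0", "req-1"]
    input_lengths = [5, 7]
    stopping_criterias = ["max_new_tokens=16", "max_new_tokens=32"]

    for i, (request, input_length, stopping_criteria) in enumerate(
        zip(requests, input_lengths, stopping_criterias)
    ):
        print(i, request, input_length, stopping_criteria)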
@@ -411,46 +405,39 @@ class Seq2SeqLM(Model):

# For each member of the batch
for i, (
-    request,
-    input_length,
-    decoder_input_length,
-    decoder_logprobs,
-    logits,
-    next_token_chooser,
-    stopping_criteria,
-    input_tokens,
-    decoder_input_ids,
+    request,
+    input_length,
+    decoder_input_length,
+    logits,
+    next_token_chooser,
+    stopping_criteria,
+    input_tokens,
+    decoder_input_ids,
) in enumerate(iterator):
    # Select next token
-    next_token, logprobs = next_token_chooser(decoder_input_ids, logits)
+    next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)

    # Append next token to decoder tokens
-    decoder_input_ids = torch.cat([decoder_input_ids, next_token])
+    decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
    new_decoder_input_length = decoder_input_length + 1

-    next_token_logprob = logprobs[-1, next_token]
-    if decoder_logprobs is None:
-        decoder_logprobs = next_token_logprob
-    else:
-        decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
+    # Generated token
+    next_token_logprob = logprobs[-1, next_token_id]
+    next_token_id_squeezed = next_token_id.squeeze()
+    next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+                                            clean_up_tokenization_spaces=False,
+                                            skip_special_tokens=False)

    # Evaluate stopping criteria
    stop, reason = stopping_criteria(
-        next_token.squeeze(),
-        self.tokenizer.decode(
-            next_token.squeeze(), clean_up_tokenization_spaces=False
-        ),
+        next_token_id,
+        next_token_text
    )

    if stop:
        # Slice with decoder_input_length to remove padding
        # Decode all tokens
-        token_ids = decoder_input_ids[-new_decoder_input_length:]
-        output_text = self.decode(token_ids)
-        tokens = self.tokenizer.batch_decode(token_ids)
-        # Add NaN for the bos token
-        logprobs = [float("nan")] + decoder_logprobs[
-            -decoder_input_length:
-        ].tolist()
+        output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])

        # Get seed
        if isinstance(next_token_chooser.choice, Sampling):
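In the hunk above, the stopping criteria call now receives the already-decoded token text instead of decoding inside the call. A simplified stand-in with the same (token_id, token_text) -> (stop, reason) shape, written only to illustrate the call contract and not the repository's StoppingCriteria implementation:

    from typing import Optional, Tuple

    class SimpleStoppingCriteria:
        """Toy stand-in: stop on a maximum number of new tokens or on a stop string."""

        def __init__(self, max_new_tokens: int = 20, stop_sequence: Optional[str] = None):
            self.max_new_tokens = max_new_tokens
            self.stop_sequence = stop_sequence
            self.current_tokens = 0
            self.text = ""

        def __call__(self, token_id: int, token_text: str) -> Tuple[bool, Optional[str]]:
            self.current_tokens += 1
            self.text += token_text
            if self.stop_sequence and self.text.endswith(self.stop_sequence):
                return True, "stop_sequence"
            if self.current_tokens >= self.max_new_tokens:
                return True, "length"
            return False, None

    stop, reason = SimpleStoppingCriteria(max_new_tokens=2)(42, " hello")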
@@ -458,27 +445,17 @@ class Seq2SeqLM(Model):
        else:
            seed = None

-        # Add to the list of finished generations with the original request
-        generated_texts.append(
-            GeneratedText(
-                request=request,
-                output_text=output_text,
-                generated_tokens=stopping_criteria.current_tokens,
-                tokens=tokens,
-                token_ids=token_ids.tolist(),
-                logprobs=logprobs,
-                reason=reason,
-                seed=seed,
-            )
+        generated_text = GeneratedText(
+            output_text, stopping_criteria.current_tokens, reason, seed
+        )
    # add to the next batch
    else:
+        # Keep request in the batch
+        generated_text = None
        next_batch_keep_indices.append(i)
        next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
        next_batch_size += 1
        next_batch_input_lengths.append(input_length)
        next_batch_decoder_input_lengths.append(new_decoder_input_length)
-        next_batch_decoder_logprobs.append(decoder_logprobs)
        next_batch_max_input_length = max(
            next_batch_max_input_length, input_length
        )
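GeneratedText now carries only the final text, the generated token count, the finish reason, and the sampling seed; the per-token ids, texts, and logprobs move to the streamed Generation objects. A hedged sketch of what the positional call above implies, with field names guessed for illustration:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GeneratedText:
        # Guessed fields matching the positional call in the hunk above.
        text: str
        generated_tokens: int
        finish_reason: str
        seed: Optional[int]

    generated_text = GeneratedText("Hello world", 2, "length", None)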
@@ -486,14 +463,39 @@ class Seq2SeqLM(Model):
    next_batch_max_decoder_input_length, new_decoder_input_length
)

+# Prefill
+if stopping_criteria.current_tokens == 1:
+    prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
+    prefill_texts = self.tokenizer.batch_decode(
+        prefill_token_ids,
+        clean_up_tokenization_spaces=False,
+        skip_special_tokens=False,
+    )
+    prefill_tokens = PrefillTokens(
+        prefill_token_ids, [float("nan")], prefill_texts
+    )
+else:
+    prefill_tokens = None
+
+generation = Generation(
+    request.id,
+    prefill_tokens,
+    next_token_id_squeezed,
+    next_token_logprob,
+    next_token_text,
+    generated_text,
+)
+
+generations.append(generation)

# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
-    return generated_texts, None
+    return generations, None

next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
-if generated_texts:
+if len(next_batch_keep_indices) != len(batch):
    # Apply indices to attention mask, past key values and other items that need to be cached
    next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
    next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
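The eviction condition changes from `if generated_texts:` to comparing the number of kept requests against `len(batch)`, which the new Batch.__len__ below enables. The filtering itself indexes every cached tensor with the surviving row indices, roughly like this toy sketch (sizes invented):

    import torch

    # Three requests were in flight; request 1 finished, so only rows 0 and 2 survive.
    input_ids = torch.arange(12).reshape(3, 4)
    attention_mask = torch.ones(3, 4, dtype=torch.long)
    next_batch_keep_indices = [0, 2]

    if len(next_batch_keep_indices) != input_ids.shape[0]:
        next_batch_input_ids = input_ids[next_batch_keep_indices]            # shape [2, 4]
        next_batch_attention_mask = attention_mask[next_batch_keep_indices]  # shape [2, 4]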
@@ -551,11 +553,10 @@ class Seq2SeqLM(Model):
    past_key_values=next_batch_past_key_values,
    input_lengths=next_batch_input_lengths,
    decoder_input_lengths=next_batch_decoder_input_lengths,
-    decoder_logprobs=next_batch_decoder_logprobs,
    next_token_choosers=next_batch_next_token_choosers,
    stopping_criterias=next_batch_stopping_criterias,
    size=next_batch_size,
    max_input_length=next_batch_max_input_length,
    max_decoder_input_length=next_batch_max_decoder_input_length,
)
-return generated_texts, next_batch
+return generations, next_batch
@@ -29,6 +29,10 @@ class Batch(ABC):
def concatenate(cls, batches: List["Batch"]) -> "Batch":
    raise NotImplementedError

+@abstractmethod
+def __len__(self):
+    raise NotImplementedError
+

@dataclass
class GeneratedText:
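Adding __len__ as an abstract method is what lets the model code above write len(batch). A minimal sketch of the contract, simplified and not the full types.py:

    from abc import ABC, abstractmethod
    from dataclasses import dataclass
    from typing import List

    class Batch(ABC):
        @abstractmethod
        def __len__(self):
            raise NotImplementedError

    @dataclass
    class ToyBatch(Batch):
        requests: List[str]

        def __len__(self):
            return len(self.requests)

    assert len(ToyBatch(requests=["a", "b"])) == 2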
|