Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 03:44:54 +00:00

Commit 4a538cfa49: "working integration tests"
Parent commit: 429155a26a
@@ -311,7 +311,7 @@ class CausalLM(Model):
         next_batch_max_sequence_length = 0
 
         # Results
-        results = []
+        generations: List[Generation] = []
 
         # Zipped iterator
         iterator = zip(
@@ -343,7 +343,9 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode(next_token_id.squeeze())
+            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+                                                    clean_up_tokenization_spaces=False,
+                                                    skip_special_tokens=False)
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
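The new right-hand side decodes the freshly generated token with the tokenizer directly, keeping special tokens and raw spacing so the stopping criteria see exactly what was produced. Below is a minimal, self-contained sketch of that decode call; the GPT-2 tokenizer and the example token id are stand-ins chosen only for illustration, not something this commit uses.

import torch
from transformers import AutoTokenizer

# Illustrative tokenizer; the server decodes with whichever model was loaded.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

next_token_id = torch.tensor([50256])  # GPT-2's <|endoftext|> id, as an example
next_token_id_squeezed = next_token_id.squeeze()

# skip_special_tokens=False keeps markers such as EOS in the text, and
# clean_up_tokenization_spaces=False avoids rewriting whitespace, so stop
# sequences are matched against the raw generated text.
next_token_text = tokenizer.decode(
    next_token_id_squeezed,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)
print(repr(next_token_text))  # '<|endoftext|>'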
@@ -381,11 +383,9 @@ class CausalLM(Model):
                 )
 
             # Prefill
-            if stopping_criteria.current_tokens == 0:
+            if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[
-                    -new_input_length:-1
-                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
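The rewritten prefill branch computes all prompt logprobs with a single gather instead of slicing first: row i of the logprob matrix is scored against the token that actually appears at position i + 1, the slice drops the freshly generated token, and a leading NaN stands in for the first prompt token, which nothing predicts. A toy reproduction of that pattern follows; shapes and values are made up, only the gather-and-shift logic mirrors the diff.

import torch

torch.manual_seed(0)
vocab_size = 10
prompt_len = 4

# Per-request column vector like the server keeps: the prompt tokens plus the
# token that was just generated, shape [prompt_len + 1, 1].
all_input_ids = torch.randint(vocab_size, (prompt_len + 1, 1))
new_input_length = prompt_len + 1

# One toy logprob row per position; row i is the distribution produced after
# reading tokens 0..i.
logprobs = torch.log_softmax(torch.randn(new_input_length, vocab_size), dim=-1)

# Row i is scored against the token at position i + 1 (the shift by one), the
# slice [-new_input_length:-1] keeps prompt positions only, and the leading NaN
# covers the first prompt token.
prefill_logprobs = [float("nan")] + logprobs.gather(
    1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()

print(len(prefill_logprobs))  # == prompt_len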
@@ -398,7 +398,7 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None
 
-            result = Generation(
+            generation = Generation(
                 request.id,
                 prefill_tokens,
                 next_token_id_squeezed,
@@ -407,11 +407,11 @@ class CausalLM(Model):
                 generated_text,
             )
 
-            results.append(result)
+            generations.append(generation)
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
-            return results, None
+            return generations, None
 
         next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
         # If we finished at least one generation, we need to evict the indices of the generations that finished
@@ -470,4 +470,4 @@ class CausalLM(Model):
             max_sequence_length=next_batch_max_sequence_length,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
-        return results, next_batch
+        return generations, next_batch

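The Generation and PrefillTokens objects above are built positionally, but their definitions in text_generation/models/types.py are not part of the hunks shown in this commit view. The sketch below is only a plausible shape consistent with those call sites; the field names are assumptions, not the repository's actual definitions.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class PrefillTokens:
    # Assumed fields matching PrefillTokens(prefill_token_ids, [nan], prefill_texts)
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]


@dataclass
class Generation:
    # Assumed fields matching Generation(request.id, prefill_tokens,
    # next_token_id_squeezed, next_token_logprob, next_token_text, generated_text)
    request_id: int
    prefill_tokens: Optional[PrefillTokens]
    token_id: torch.Tensor
    token_logprob: float
    token_text: str
    generated_text: Optional[object]  # the existing GeneratedText, or None mid-stream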
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz
 from typing import Optional, Tuple, List, Type
 
 from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
 
@@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
-    decoder_logprobs: List[Optional[torch.Tensor]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -51,10 +50,10 @@ class Seq2SeqLMBatch(Batch):
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch):
 
         decoder_input_ids = []
         decoder_input_lengths = []
-        decoder_logprobs = []
 
         # Parse batch
         for r in pb.requests:
@@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch):
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
-            decoder_logprobs.append(None)
 
         # Tokenize batch
         pad_to_multiple_of = 8 if device.type == "cuda" else None
@@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
-            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch):
         requests = []
         input_lengths = []
         decoder_input_lengths = []
-        decoder_logprobs = []
         next_token_choosers = []
         stopping_criterias = []
 
@@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
-            decoder_logprobs.extend(batch.decoder_logprobs)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
@@ -164,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             input_ids[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.input_ids[:, -batch.max_input_length :]
+                start_index:end_index, -batch.max_input_length:
+            ] = batch.input_ids[:, -batch.max_input_length:]
 
             # Create padded tensor
             if attention_mask is None:
@@ -174,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             attention_mask[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.attention_mask[:, -batch.max_input_length :]
+                start_index:end_index, -batch.max_input_length:
+            ] = batch.attention_mask[:, -batch.max_input_length:]
 
             # Create padded tensor
             if decoder_input_ids is None:
@@ -184,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             decoder_input_ids[
-                start_index:end_index, -batch.max_decoder_input_length :
-            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
+                start_index:end_index, -batch.max_decoder_input_length:
+            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]
 
             # Create padded tensor
             if decoder_attention_mask is None:
@@ -197,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
             # this batch. All generations are of length `batch.max_decoder_input_length`.
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length :
+                    start_index:end_index, -batch.max_decoder_input_length:
                 ] = 1
             # If it exists, we need to index
             else:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length :
-                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
+                    start_index:end_index, -batch.max_decoder_input_length:
+                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]
 
             # Create padded tensor
             if encoder_last_hidden_state is None:
@@ -217,8 +211,8 @@ class Seq2SeqLMBatch(Batch):
 
             # Copy to correct indices
             encoder_last_hidden_state[
-                start_index:end_index, -batch.max_input_length :, :
-            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
+                start_index:end_index, -batch.max_input_length:, :
+            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]
 
             # Iterate over attention layers
             for j, past in enumerate(batch.past_key_values):
@@ -244,11 +238,11 @@ class Seq2SeqLMBatch(Batch):
 
                     # We slice the past keys and values to remove the padding from previous batches
                     past_key_values[j][k][
                         start_index:end_index,
                         :,
-                        -(batch.max_decoder_input_length - 1) :,
+                        -(batch.max_decoder_input_length - 1):,
                         :,
-                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
+                    ] = t[:, :, -(batch.max_decoder_input_length - 1):, :]
 
             # encoder past
             for k, t in enumerate(past[2:]):
@@ -267,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
                     past_key_values[j].append(t.new_zeros(padded_t_shape))
 
                 past_key_values[j][idx][
-                    start_index:end_index, :, -batch.max_input_length :, :
-                ] = t[:, :, -batch.max_input_length :, :]
+                    start_index:end_index, :, -batch.max_input_length:, :
+                ] = t[:, :, -batch.max_input_length:, :]
 
             start_index += batch.size
 
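Every slice change in the hunks above is whitespace only (the space before ":" goes away); the copy semantics are unchanged. The pattern itself is worth spelling out: batches are padded on the left, so concatenation copies only the last max_input_length columns of each source tensor into the right edge of the merged buffer. A standalone toy example of that right-aligned copy, with made-up sizes:

import torch

# Two toy per-request tensors padded to different lengths, tokens right-aligned
# (padding lives on the left, real tokens hug the right edge).
long_batch = torch.tensor([[1, 2, 3, 4, 5, 6]])   # padded length 6, all real
short_batch = torch.tensor([[0, 0, 7, 8]])        # padded length 4, 2 real tokens
short_max_input_length = 2

# Merged buffer padded to the larger length.
input_ids = torch.zeros((2, 6), dtype=torch.long)
input_ids[0:1, :] = long_batch
# Copy only the real tokens of the shorter batch into the right-most columns,
# the same "-batch.max_input_length:" slice used above.
input_ids[1:2, -short_max_input_length:] = short_batch[:, -short_max_input_length:]

print(input_ids)
# tensor([[1, 2, 3, 4, 5, 6],
#         [0, 0, 0, 0, 7, 8]])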
@@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
-            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length=max_decoder_input_length,
         )
 
+    def __len__(self):
+        return len(self.requests)
+
 
 class Seq2SeqLM(Model):
     def __init__(self, model_name: str, quantize=False):
@@ -326,13 +322,13 @@ class Seq2SeqLM(Model):
         return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
 
     def forward(
         self,
         input_ids,
         attention_mask,
         decoder_input_ids,
         decoder_attention_mask: Optional,
         encoder_last_hidden_state: Optional,
         past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
@@ -363,8 +359,8 @@ class Seq2SeqLM(Model):
         )
 
     def generate_token(
         self, batch: Seq2SeqLMBatch
-    ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
             torch.no_grad if self.device.type == "cpu" else torch.inference_mode
@@ -386,7 +382,6 @@ class Seq2SeqLM(Model):
         next_batch_input_lengths = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
-        next_batch_decoder_logprobs = []
 
         # Metadata
         next_batch_size = 0
@@ -394,14 +389,13 @@ class Seq2SeqLM(Model):
         next_batch_max_decoder_input_length = 0
 
         # Finished requests
-        generated_texts: List[GeneratedText] = []
+        generations: List[Generation] = []
 
         # Zipped iterator
         iterator = zip(
             batch.requests,
             batch.input_lengths,
             batch.decoder_input_lengths,
-            batch.decoder_logprobs,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -411,46 +405,39 @@ class Seq2SeqLM(Model):
 
         # For each member of the batch
         for i, (
             request,
             input_length,
             decoder_input_length,
-            decoder_logprobs,
             logits,
             next_token_chooser,
             stopping_criteria,
             input_tokens,
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            next_token, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
 
             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
             new_decoder_input_length = decoder_input_length + 1
 
-            next_token_logprob = logprobs[-1, next_token]
-            if decoder_logprobs is None:
-                decoder_logprobs = next_token_logprob
-            else:
-                decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+                                                    clean_up_tokenization_spaces=False,
+                                                    skip_special_tokens=False)
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token.squeeze(),
-                self.tokenizer.decode(
-                    next_token.squeeze(), clean_up_tokenization_spaces=False
-                ),
+                next_token_id,
+                next_token_text
             )
 
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                token_ids = decoder_input_ids[-new_decoder_input_length:]
-                output_text = self.decode(token_ids)
-                tokens = self.tokenizer.batch_decode(token_ids)
-                # Add NaN for the bos token
-                logprobs = [float("nan")] + decoder_logprobs[
-                    -decoder_input_length:
-                ].tolist()
+                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
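The loop now expects the next-token chooser to return both the chosen token id and the full log-probability matrix, so the generated token's logprob can be read as logprobs[-1, next_token_id]. The snippet below is a greedy stand-in written only to illustrate that calling convention; it is an assumption, and the real NextTokenChooser also applies warpers such as temperature, top-k and top-p before choosing.

import torch


def greedy_chooser(input_ids: torch.Tensor, logits: torch.Tensor):
    # Toy stand-in: return (next_token_id, logprobs) the way the diff assumes.
    logprobs = torch.log_softmax(logits, dim=-1)              # [seq_len, vocab]
    next_token_id = logits[-1].argmax(dim=-1, keepdim=True)   # shape [1]
    return next_token_id, logprobs


decoder_input_ids = torch.tensor([0, 4, 2])   # toy decoder prefix
logits = torch.randn(3, 10)                   # toy decoder logits

next_token_id, logprobs = greedy_chooser(decoder_input_ids, logits)
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()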
@@ -458,27 +445,17 @@ class Seq2SeqLM(Model):
                 else:
                     seed = None
 
-                # Add to the list of finished generations with the original request
-                generated_texts.append(
-                    GeneratedText(
-                        request=request,
-                        output_text=output_text,
-                        generated_tokens=stopping_criteria.current_tokens,
-                        tokens=tokens,
-                        token_ids=token_ids.tolist(),
-                        logprobs=logprobs,
-                        reason=reason,
-                        seed=seed,
-                    )
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
                 )
-                # add to the next batch
             else:
+                # Keep request in the batch
+                generated_text = None
                 next_batch_keep_indices.append(i)
                 next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                 next_batch_size += 1
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                next_batch_decoder_logprobs.append(decoder_logprobs)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
@@ -486,14 +463,39 @@ class Seq2SeqLM(Model):
                 next_batch_max_decoder_input_length = max(
                     next_batch_max_decoder_input_length, new_decoder_input_length
                 )
 
+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, [float("nan")], prefill_texts
+                )
+            else:
+                prefill_tokens = None
+
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)
+
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
-            return generated_texts, None
+            return generations, None
 
         next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
         # If we finished at least one generation, we need to evict the indices of the generations that finished
         # from the values of the next batch
-        if generated_texts:
+        if len(next_batch_keep_indices) != len(batch):
             # Apply indices to attention mask, past key values and other items that need to be cached
             next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
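When only part of the batch finishes, the surviving requests are kept by indexing every cached tensor with next_batch_keep_indices, and the new len(next_batch_keep_indices) != len(batch) test (enabled by the __len__ methods added in this commit) decides whether that eviction is needed at all. A toy illustration of the step, with arbitrary tensor contents:

import torch

# Toy batch of 3 requests; suppose request 1 finished this step.
input_ids = torch.arange(12).reshape(3, 4)
attention_mask = torch.ones(3, 4, dtype=torch.long)
next_batch_keep_indices = [0, 2]       # requests that keep generating

if len(next_batch_keep_indices) != 3:  # stands in for len(batch) in the diff
    # Integer-list indexing keeps only the surviving rows.
    next_batch_input_ids = input_ids[next_batch_keep_indices]
    next_batch_attention_mask = attention_mask[next_batch_keep_indices]

print(next_batch_input_ids.shape)      # torch.Size([2, 4])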
@@ -551,11 +553,10 @@ class Seq2SeqLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
-            decoder_logprobs=next_batch_decoder_logprobs,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_input_length=next_batch_max_input_length,
             max_decoder_input_length=next_batch_max_decoder_input_length,
         )
-        return generated_texts, next_batch
+        return generations, next_batch

@@ -29,6 +29,10 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError
 
+    @abstractmethod
+    def __len__(self):
+        raise NotImplementedError
+
 
 @dataclass
 class GeneratedText:
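The abstract __len__ on Batch is what makes len(batch) legal in model code, such as the eviction check in Seq2SeqLM above. A small sketch of the contract with a dummy concrete batch (everything except Batch.__len__ itself is illustrative):

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List


class Batch(ABC):
    @abstractmethod
    def __len__(self):
        raise NotImplementedError


@dataclass
class DummyBatch(Batch):
    requests: List[int]

    def __len__(self):
        # Mirrors Seq2SeqLMBatch.__len__ above: one entry per request.
        return len(self.requests)


batch = DummyBatch(requests=[10, 11, 12])
next_batch_keep_indices = [0, 2]
print(len(next_batch_keep_indices) != len(batch))  # True: one request finished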