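"""Generic causal language modeling backend.

`CausalLM` wraps a Hugging Face `AutoModelForCausalLM`, while `CausalLMBatch`
keeps the padded tensors, past key/values and per-request generation state
needed to run batched, incremental decoding for the text_generation server.
"""
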
import torch

from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling


@dataclass
class CausalLMBatch(Batch):
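    """State for one batch of requests.

    Tensors (`input_ids`, `attention_mask`, `position_ids`, `past_key_values`)
    are padded up to the longest sequence in the batch, while the per-request
    lists (`all_input_ids`, `input_lengths`, `next_token_choosers`,
    `stopping_criterias`) keep one entry per request, in the same order as
    `requests`.
    """
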
    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
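        """Build a prefill batch from its protobuf representation.

        Tokenizes every request (with left padding) on the target device and
        instantiates the per-request `NextTokenChooser` and `StoppingCriteria`.
        """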
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )

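        # Pad to a multiple of 8 on GPU (tensor-core friendly shapes)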
        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)
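        # Position ids are the (0-indexed) cumulative sum of the attention mask;
        # padded positions are masked out and set to 1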
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=max(input_lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
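        """Merge several prefilled batches into a single batch.

        Re-pads `attention_mask` and `past_key_values` to the new
        `max_sequence_length`; every input batch must already have run at least
        one forward pass (i.e. have a non-None `past_key_values`).
        """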
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_sequence_length),
                )

            # We need to slice the attention mask to remove padding from previous steps
            attention_mask[
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.attention_mask[:, -batch.max_sequence_length :]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past

                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, padded_sequence_length, head_dim = past_values.shape

                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_sequence_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_sequence_length - 1,
                    )

                # This will run only once per layer
                if j == len(past_key_values):
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        -(batch.max_sequence_length - 1) :,
                        :,
                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
                        -(batch.max_sequence_length - 1) :,
                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                past_key_values[j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
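    """Causal language model served through `AutoModelForCausalLM`.

    Rough lifecycle (illustrative sketch only; "gpt2" and `pb_batch` are
    placeholder inputs supplied by the caller):

        model = CausalLM("gpt2")
        batch = CausalLMBatch.from_pb(pb_batch, model.tokenizer, model.device)
        generations, batch = model.generate_token(batch)      # prefill
        while batch is not None:
            generations, batch = model.generate_token(batch)  # one token per call
    """
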
    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_name, revision=revision, padding_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
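        """Decode a full sequence of generated ids, skipping special tokens."""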
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
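        """Single forward pass; returns the logits and the updated past key/values."""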
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
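        """Run one decoding step for the whole batch.

        Returns one `Generation` per request and the batch to use for the next
        step, or `None` if every request reached its stopping criteria.
        """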
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.position_ids,
                batch.past_key_values,
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
            next_token_id = tokens[-1].view(1, 1)

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-new_input_length:-1].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_attention_mask = torch.cat(
            [
                next_batch_attention_mask,
                next_batch_attention_mask.new_ones(next_batch_size, 1),
            ],
            dim=1,
        )

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch