text-generation-inference/server/text_generation/models/causal_lm.py

import torch

from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling


@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )

        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
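        # Illustrative sketch (values assumed, not from this file): a left-padded
        # attention_mask row [0, 0, 1, 1, 1] gives cumsum - 1 = [-1, -1, 0, 1, 2];
        # padded positions are then overwritten with a dummy value of 1, so the real
        # tokens keep positions 0, 1, 2 regardless of how much padding precedes them.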

        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=max(input_lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
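        """Merge several prefilled batches into one, re-padding attention masks and
        past key/value tensors on the left so every sequence shares the new maximum
        length and the merged batch can keep decoding together."""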
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_sequence_length),
                )

            # We need to slice the attention mask to remove padding from previous steps
            attention_mask[
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.attention_mask[:, -batch.max_sequence_length :]
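            # Illustrative sketch (values assumed, not from this file): if this batch
            # was padded to length 3 and the merged batch uses length 5, a mask row
            # [1, 1, 1] is copied into columns [-3:] of a zeroed row, giving
            # [0, 0, 1, 1, 1]; the extra zeros act as additional left padding.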

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past

                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
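                # Illustrative note (numbers are assumed, not from this file): with
                # batch.size == 2 and 16 attention heads, BLOOM-style keys of shape
                # [32, head_dim, seq_length] are unfolded below to
                # [2, 16, head_dim, seq_length] so rows can be indexed per request.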
                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

                _, num_heads, padded_sequence_length, head_dim = past_values.shape

                padded_past_values_shape = (
                    total_batch_size,
                    num_heads,
                    max_sequence_length - 1,
                    head_dim,
                )

                if batch.keys_head_dim_last:
                    padded_past_keys_shape = padded_past_values_shape
                else:
                    # seq_length is last for BLOOM
                    padded_past_keys_shape = (
                        total_batch_size,
                        num_heads,
                        head_dim,
                        max_sequence_length - 1,
                    )

                # This will run only once per layer
                if j == len(past_key_values):
                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                    past_key_values.append((padded_past_keys, padded_past_values))

                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        -(batch.max_sequence_length - 1) :,
                        :,
                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
                        -(batch.max_sequence_length - 1) :,
                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                past_key_values[j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_name, revision=revision, padding_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
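        """Single forward pass through the model, returning the logits and the updated
        past key/value cache (use_cache=True) for the next decoding step."""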
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
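        """Run one decoding step for the whole batch and return the per-request
        Generation objects along with the batch to use for the next step, or None
        if every request has finished."""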
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.position_ids,
                batch.past_key_values,
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
            next_token_id = tokens[-1].view(1, 1)

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-new_input_length:-1].tolist()
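                # Illustrative note (not from this file): gather picks, for each prompt
                # token after the first, the logprob that the previous position assigned
                # to it; the first prompt token has no preceding position, hence the
                # leading NaN placeholder.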
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_position_ids = batch.position_ids
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_attention_mask = torch.cat(
            [
                next_batch_attention_mask,
                next_batch_attention_mask.new_ones(next_batch_size, 1),
            ],
            dim=1,
        )

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
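        # Only the last position is carried forward: with the past key/value cache the
        # next forward pass feeds a single new token per sequence, whose position id is
        # simply the previous one plus one.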

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch
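

# A minimal usage sketch (illustrative assumption, not part of the original file):
# the gRPC server normally drives this class, roughly as follows, where `pb_batch`
# is a generate_pb2.Batch received from the router and the model name is arbitrary:
#
#     model = CausalLM("bigscience/bloom-560m")
#     batch = CausalLMBatch.from_pb(pb_batch, model.tokenizer, model.device)
#     while batch is not None:
#         generations, batch = model.generate_token(batch)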