# text-generation-inference/server/text_generation_server/models/flash_causal_lm.py
import torch
import torch.distributed
import numpy as np
from torch.nn import functional as F
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
)
# Module-level OpenTelemetry tracer; `filter`, `concatenate` and
# `generate_token` below are wrapped in spans started from it.
tracer = trace.get_tracer(__name__)
@dataclass
class FlashCausalLMBatch(Batch):
    """Request batch for flash-attention causal LMs.

    Sequences are not padded. Token-level tensors hold the tokens of every
    sequence back to back: all prompt tokens before the first step, then one
    token per sequence. ``cu_seqlens`` records the cumulative sequence
    boundaries (starting at 0); these boundaries are also used to slice
    ``past_key_values`` along dim 1.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    # cumulative query sequence lengths, only used in decode
    cu_seqlens_q: Optional[torch.Tensor]
    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # Incremental detokenization state per request (initialised to None)
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch metadata to its protobuf representation."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize the requests of ``pb`` and build a prefill batch on ``device``.

        Each request is tokenized with truncation to at most ``r.truncate``
        tokens. All prompts are packed into one flat ``input_ids`` tensor, and
        position ids restart at 0 for every sequence. ``max_tokens`` sums
        prompt length plus ``max_new_tokens`` over all requests.

        Args:
            pb: protobuf batch holding the requests.
            tokenizer: tokenizer used for the prompts and stopping criteria.
            device: device on which the tensors are created.

        Returns:
            A batch with no past key values (i.e. ready for prefill).
        """
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        offsets = []
        token_offsets = []

        all_input_ids = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        max_tokens = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            offsets.append(None)
            token_offsets.append(None)

            all_input_ids.append(tokenized_input)

            # Position ids restart at 0 for every sequence
            position_ids.append(np.arange(0, input_length))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens

        # Create tensors on device: every sequence packed into one flat tensor
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=[],
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
        """Return a batch that keeps only ``requests``.

        Args:
            requests: requests to keep. Each id must be a key of
                ``requests_idx_mapping``.

        Returns:
            ``self`` when ``requests`` has as many entries as the batch (the
            requests are then assumed to be the same). Otherwise, a new batch
            whose per-request state and past key values are sliced out of this
            one. A lone surviving request gets its past padded with room for
            its remaining ``max_new_tokens - current_tokens`` tokens
            (pre-allocation for the batch-size-1 path).

        Raises:
            ValueError: if ``requests`` is empty.
        """
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        single_request = len(requests) == 1

        # Read the current boundaries to the host once. Slicing the past with
        # device scalars would force a GPU sync for every bound.
        previous_cu_seqlens = self.cu_seqlens.tolist()

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = self.input_ids.new_empty(len(requests))
        position_ids = self.position_ids.new_empty(len(requests))
        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
        cu_seqlens_q = torch.arange(
            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            # Copy tensors (GPU)
            input_ids[i] = self.input_ids[idx]
            position_ids[i] = self.position_ids[idx]

            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)

            # Slice from past
            past_key_values.append(
                self.past_key_values[
                    :, previous_cu_seqlens[idx] : previous_cu_seqlens[idx + 1]
                ]
            )

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case.
            # F.pad pads from the last dim backwards, so the final pair pads the
            # 4th-from-last dim. NOTE(review): this is presumably the token dim
            # (dim 1) of a 5-D past — confirm against the model.
            remaining_tokens = (
                stopping_criterias[0].max_new_tokens
                - stopping_criterias[0].current_tokens
            )
            past_key_values = F.pad(
                past_key_values[0],
                (0, 0, 0, 0, 0, 0, 0, remaining_tokens),
            )
        else:
            # Cat all past
            past_key_values = torch.cat(past_key_values, dim=1)

        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several decode batches into one.

        The result keeps ``batches[0].batch_id``. Request indices and sequence
        boundaries of later batches are shifted by everything that precedes
        them. Size-1 batches carry a pre-allocated (padded) past, which is
        trimmed to the request's current input length before concatenation.
        The input batches are not mutated.

        Args:
            batches: non-empty list of batches that have already been prefilled.

        Returns:
            The concatenated batch.
        """
        # Batch attributes
        requests = []
        # Always build a fresh mapping: reusing batches[0].requests_idx_mapping
        # would silently mutate the first input batch.
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        # Boundary chunks stay on the device and are concatenated once at the
        # end, avoiding a host sync per element.
        cu_seqlens = [batches[0].cu_seqlens.new_zeros(1)]
        cu_seqlens_q = torch.arange(
            0, total_batch_size + 1, device=device, dtype=torch.int32
        )

        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the number of requests before it
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids

            # Shift this batch's boundaries by the tokens of all previous batches
            cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)

            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.append(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update (out-of-place so no earlier tensor is modified)
            cumulative_length = cumulative_length + batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.cat(cu_seqlens).to(device=device, dtype=torch.int32)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)
class FlashCausalLM(Model):
    """Causal language model served with flash attention on packed batches.

    Works on ``FlashCausalLMBatch``, whose sequences are concatenated without
    padding (``requires_padding=False``). Only CUDA devices are supported, and
    weights are loaded in ``torch.float16``.
    """

    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        """Load the tokenizer and the model weights onto the GPU.

        Args:
            model_cls: model class whose ``from_pretrained`` loads the weights.
            model_id: identifier passed to ``from_pretrained``.
            revision: optional revision of ``model_id``.
            quantize: forwarded as ``load_in_8bit`` to ``from_pretrained``.
            decode_buffer: forwarded unchanged to ``Model.__init__``.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = (
            model_cls.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
            )
            .eval()
            .to(device)
        )

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        """Detokenize ``generated_ids``, skipping special tokens."""
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        max_s: int,
        past_key_values: Optional = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the wrapped model on a packed batch.

        All arguments are forwarded as keyword arguments to
        ``self.model.forward``.

        Returns:
            ``(out, present)`` as produced by the model: ``out`` holds the
            logits used for sampling, and ``present`` the key/value cache.
        """
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        """Run one prefill or decode step and select the next token per request.

        The step is a prefill when ``batch.past_key_values`` is None; otherwise
        it is a decode step. ``batch`` is updated in place:
        input/position ids, ``cu_seqlens``, past key values, lengths,
        detokenization offsets and ``max_seqlen``.

        Returns:
            A tuple of two elements:
            - The generations for the requests handled by this shard
              (``i % world_size == rank``).
            - The updated batch, or None once every request has stopped.
        """
        prefill = batch.past_key_values is None

        if prefill and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.cu_seqlens,
            batch.cu_seqlens_q,
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
        )

        if prefill:
            if len(batch) > 1:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
            next_input_ids = batch.input_ids.new_empty(len(batch))
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
            next_input_ids = batch.input_ids
            next_position_ids = batch.position_ids

        next_token_logprobs = out.new_empty(len(batch))

        # Prepare past for next decode
        if len(batch) > 1:
            # Used to slice next batch past
            past_indices = torch.empty(
                present.shape[1], dtype=torch.int64, device=self.device
            )
            # One extra slot per sequence, for the token generated in this step
            batch.past_key_values = present.new_empty(
                (
                    present.shape[0],
                    present.shape[1] + len(batch.requests),
                    *present.shape[2:],
                )
            )

            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
            # and will run asynchronously while we do the next for loop
            cumulative_length = 0
            for i, input_length in enumerate(batch.input_lengths):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Indices to copy present at the correct place in past_key_values
                torch.arange(
                    start_index + i,
                    end_index + i,
                    dtype=torch.int64,
                    device=self.device,
                    out=past_indices[start_index:end_index],
                )
                cumulative_length += input_length

            # Copy from present to past_key_values
            batch.past_key_values[:, past_indices] = present

        # Initialize past_key_values in prefill for len(batch) == 1
        elif prefill:
            # present is already pre-padded
            batch.past_key_values = present

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                # only take last token logit
                logits = out[end_index - 1 : end_index]

                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
                all_input_ids_tensor = batch.input_ids.new_empty(
                    input_length + stopping_criteria.max_new_tokens
                )
                # Copy from batch.input_ids to all_input_ids_tensor
                all_input_ids_tensor[:input_length] = batch.input_ids[
                    start_index:end_index
                ]
                batch.all_input_ids_tensor.append(all_input_ids_tensor)

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if len(batch) > 1:
                    prefill_tokens_indices[
                        start_index : end_index - 1
                    ] = batch.input_ids[start_index + 1 : end_index]
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].view(1, -1)

                all_input_ids_tensor = batch.all_input_ids_tensor[i]

            # Select next token
            next_token_id, logprob = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )

            # Add to all_input_ids_tensor
            next_token_id_squeezed = next_token_id.view(1)
            all_input_ids_tensor[input_length] = next_token_id_squeezed

            # Set values
            next_input_ids[i] = next_token_id_squeezed
            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

        if prefill:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        cumulative_length = 0

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids,
                offset,
                token_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        start_index : end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.all_input_ids[i] = all_input_ids
            # Keep max_seqlen equal to the longest updated sequence. Incrementing
            # it here unconditionally would add len(batch) per step instead of 1.
            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
            cumulative_length += input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None