OlivierDehaene 2023-04-14 12:33:44 +02:00
parent b6ee0ec7b0
commit 4e63d9cb28
5 changed files with 237 additions and 182 deletions

proto/generate.proto

@@ -132,7 +132,7 @@ message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional Batch batch = 2;
Batch batch = 2;
}
message DecodeRequest {
@@ -144,5 +144,5 @@ message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional Batch batch = 2;
Batch batch = 2;
}

router/src/infer.rs

@@ -7,9 +7,8 @@ use futures::future::try_join_all;
use futures::stream::StreamExt;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use flume::SendError;
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
use thiserror::Error;
use tokio::sync::{Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
@@ -339,7 +338,21 @@ async fn prefill(
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
filter_send_generations(generations, entries);
let next_batch = {
let mut batch = next_batch.expect("next_batch is None. This is a bug.");
batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
let size = batch.requests.len();
if size == 0 {
let _ = client.clear_cache(Some(batch.id)).await;
return None;
}
batch.size = size as u32;
Some(batch)
};
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch
@@ -361,17 +374,35 @@ async fn decode(
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
match client.decode(batches).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
filter_send_generations(generations, entries);
let next_batch = {
let mut batch = next_batch.expect("next_batch is None. This is a bug.");
batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
let size = batch.requests.len();
if size == 0 {
let _ = client.clear_cache(Some(batch.id)).await;
return None;
}
batch.size = size as u32;
Some(batch)
};
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
None
@@ -398,64 +429,66 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&generation.request_id)
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
.unwrap_or(());
}
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message
// We can `expect` here as the request id should always be in the entries
let entry = entries
.remove(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))
.unwrap_or(());
} else {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))
.unwrap_or(());
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation back to infer task
// If we receive an error from the Flume channel, we need to stop generating for this
// request, hence the unwrap_or(true)
let stopped = send_generation(generation, entry).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
entry.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
// Send message
entry.response_tx
.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
} else {
// Send message
entry.response_tx
.send(Ok(InferStreamResponse::Token(token)))
?;
}
Ok(stopped)
}
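
The Rust change above splits dispatch from bookkeeping: send_generation only forwards messages and reports whether the request is finished, while filter_send_generations evicts an entry when the generation ended or the Flume send failed (the client went away), so only live requests remain for the next batch. A rough, self-contained Python sketch of that dispatch-and-evict idea, with invented names standing in for the entries map and the per-request channel:

def send_generation(generation, entry):
    # Returns True when no more messages should be produced for this request.
    entry["send"](generation)                                 # may raise if the receiver is gone
    return generation.get("generated_text") is not None       # final message marks the end

def filter_send_generations(generations, entries):
    for generation in generations:
        request_id = generation["request_id"]
        try:
            stopped = send_generation(generation, entries[request_id])
        except ConnectionError:
            stopped = True                                     # treat a failed send as "stop generating"
        if stopped:
            del entries[request_id]

entries = {1: {"send": print}}
filter_send_generations([{"request_id": 1, "generated_text": "done"}], entries)
assert not entries                                             # request 1 finished and was evicted
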
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message

server/text_generation_server/models/causal_lm.py

@@ -3,7 +3,7 @@ import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
@@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
class CausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
@@ -42,7 +43,6 @@ class CausalLMBatch(Batch):
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
size: int
max_input_length: int
padding_right_offset: int
@@ -53,26 +53,28 @@ class CausalLMBatch(Batch):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
size=len(self),
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
offsets = []
token_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
for r in pb.requests:
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
offsets.append(None)
token_offsets.append(None)
@@ -108,26 +110,88 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
all_input_ids = []
max_input_length = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
keep_indices.append(idx)
requests_idx_mapping[r.id] = i
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(
max_input_length, request_input_length
)
# Replace metadata
self.requests_idx_mapping = requests_idx_mapping
self.input_lengths = input_lengths
self.offsets = offsets
self.token_offsets = token_offsets
self.all_input_ids = all_input_ids
self.max_input_length = max_input_length
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
self.input_ids = self.input_ids[keep_indices]
self.attention_mask = self.attention_mask[keep_indices]
self.position_ids = self.position_ids[keep_indices]
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
self.past_key_values = [
[
t.view(len(self), -1, *t.shape[-2:])[keep_indices]
for t in layer
]
for layer in self.past_key_values
]
self.requests = [self.requests[i] for i in keep_indices]
self.next_token_choosers = [
self.next_token_choosers[i] for i in keep_indices
]
self.stopping_criterias = [
self.stopping_criterias[i] for i in keep_indices
]
return self
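
The new filter method leans on requests_idx_mapping to turn the surviving request ids into row indices, gathers those rows out of every per-request list and every batched tensor, and then rebuilds the mapping for the smaller batch. A minimal sketch of that gather pattern on toy data (plain torch, no TGI types; all names here are illustrative):

import torch

requests_idx_mapping = {10: 0, 11: 1, 12: 2}       # request id -> row in the batch
input_lengths = [5, 3, 7]                          # one entry per request
attention_mask = torch.ones(3, 8)                  # one row per request

kept_request_ids = [10, 12]                        # e.g. request 11 finished or was dropped
keep_indices = [requests_idx_mapping[r] for r in kept_request_ids]

input_lengths = [input_lengths[i] for i in keep_indices]     # Python lists: rebuild by index
attention_mask = attention_mask[keep_indices]                # tensors: slice the batch dim
requests_idx_mapping = {r: i for i, r in enumerate(kept_request_ids)}  # remap to new rows

assert input_lengths == [5, 7]
assert attention_mask.shape == (2, 8)
assert requests_idx_mapping == {10: 0, 12: 1}
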
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
@@ -136,12 +200,13 @@ class CausalLMBatch(Batch):
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
@@ -167,8 +232,14 @@ class CausalLMBatch(Batch):
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + batch.size
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
@@ -192,17 +263,17 @@ class CausalLMBatch(Batch):
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
batch_left_offset: -batch.padding_right_offset,
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
@@ -216,8 +287,8 @@ class CausalLMBatch(Batch):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
_, num_heads, padded_sequence_length, head_dim = past_values.shape
@@ -248,28 +319,29 @@ class CausalLMBatch(Batch):
# We slice the past keys and values to remove the padding from previous batches
if batch.keys_head_dim_last:
past_key_values[j][0][
start_index:end_index,
:,
-(batch.max_input_length - 1) :,
:,
] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
start_index:end_index,
:,
-(batch.max_input_length - 1):,
:,
] = past_keys[:, :, -(batch.max_input_length - 1):, :]
else:
past_key_values[j][0][
start_index:end_index,
:,
:,
-(batch.max_input_length - 1) :,
] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
start_index:end_index,
:,
:,
-(batch.max_input_length - 1):,
] = past_keys[:, :, :, -(batch.max_input_length - 1):]
past_key_values[j][1][
start_index:end_index, :, -(batch.max_input_length - 1) :, :
] = past_values[:, :, -(batch.max_input_length - 1) :, :]
start_index:end_index, :, -(batch.max_input_length - 1):, :
] = past_values[:, :, -(batch.max_input_length - 1):, :]
start_index += batch.size
start_index += len(batch)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -280,7 +352,6 @@ class CausalLMBatch(Batch):
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
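
When batches are concatenated, each later batch's requests_idx_mapping is shifted by the running start_index so that every request id keeps pointing at its row in the merged tensors. A small worked example with made-up ids:

# Batch A holds requests 7 and 9 in rows 0 and 1; batch B holds 3, 11, 12 in rows 0-2.
mapping_a = {7: 0, 9: 1}
mapping_b = {3: 0, 11: 1, 12: 2}
len_a = 2                                   # len(batch_a), the start_index for batch B

merged = dict(mapping_a)                    # the first batch keeps its indices
for request_id, idx in mapping_b.items():
    merged[request_id] = idx + len_a        # later batches are offset

assert merged == {7: 0, 9: 1, 3: 2, 11: 3, 12: 4}
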
@@ -292,11 +363,11 @@ class CausalLMBatch(Batch):
class CausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -338,7 +409,7 @@ class CausalLM(Model):
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
outputs = self.model.forward(
@@ -352,8 +423,8 @@ class CausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
self, batch: CausalLMBatch
) -> Tuple[List[Generation], CausalLMBatch]:
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
@@ -364,19 +435,8 @@ class CausalLM(Model):
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_input_ids = []
next_batch_all_input_ids = []
# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
# Results
generations: List[Generation] = []
@@ -395,14 +455,14 @@ class CausalLM(Model):
# For each member of the batch
for i, (
request,
input_length,
offset,
token_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
request,
input_length,
offset,
token_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@@ -429,7 +489,7 @@ class CausalLM(Model):
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
all_input_ids[-stopping_criteria.current_tokens:, 0]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@@ -443,16 +503,6 @@ class CausalLM(Model):
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token_id)
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_max_input_length = max(
next_batch_max_input_length, new_input_length
)
# Prefill
if stopping_criteria.current_tokens == 1:
@@ -484,62 +534,25 @@ class CausalLM(Model):
generations.append(generation)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generations, None
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_past_key_values = [
[
t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
for t in layer
]
for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_attention_mask = batch.attention_mask
next_batch_position_ids = batch.position_ids
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update values
next_batch_input_ids.append(next_token_id)
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# Decrease right offset
batch.padding_right_offset -= 1
# Create input_ids tensor
batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
# Update attention_mask as we added a new token to input_ids
next_batch_attention_mask[:, -batch.padding_right_offset] = 1
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
batch.position_ids = batch.position_ids[:, -1:] + 1
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
attention_mask=next_batch_attention_mask,
position_ids=next_batch_position_ids,
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
padding_right_offset=batch.padding_right_offset - 1,
keys_head_dim_last=batch.keys_head_dim_last,
)
return generations, next_batch
# Update past key values
batch.past_key_values = past
return generations, batch

server/text_generation_server/models/types.py

@@ -25,6 +25,10 @@ class Batch(ABC):
) -> "Batch":
raise NotImplementedError
@abstractmethod
def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
raise NotImplementedError
@classmethod
@abstractmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":

server/text_generation_server/server.py

@@ -60,7 +60,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.cache.pop(batch_pb.id)
if batch is None:
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
batches.append(batch)
batch = batch.filter(batch_pb.requests)
if batch is not None:
batches.append(batch)
if len(batches) == 0:
raise ValueError("All batches are empty")
if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches)
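
Putting the pieces together, the Decode path now pops each referenced batch from the cache, shrinks it to the requests the router still cares about, and only concatenates what is left. A toy end-to-end sketch with a deliberately simplified FakeBatch (the real filter takes request protos and its empty-batch handling differs; everything here is illustrative):

from typing import Dict, List, Optional

class FakeBatch:
    # Minimal stand-in for a cached model batch: just the request ids it holds.
    def __init__(self, batch_id: int, request_ids: List[int]):
        self.batch_id = batch_id
        self.request_ids = request_ids

    def filter(self, keep_ids: List[int]) -> Optional["FakeBatch"]:
        kept = [r for r in self.request_ids if r in keep_ids]
        return FakeBatch(self.batch_id, kept) if kept else None

    @classmethod
    def concatenate(cls, batches: List["FakeBatch"]) -> "FakeBatch":
        return cls(batches[0].batch_id, [r for b in batches for r in b.request_ids])

cache: Dict[int, FakeBatch] = {0: FakeBatch(0, [1, 2, 3]), 1: FakeBatch(1, [4, 5])}
decode_request = [(0, [1, 3]), (1, [])]     # (batch id, requests still alive); batch 1 is empty

batches = []
for batch_id, live_ids in decode_request:
    filtered = cache.pop(batch_id).filter(live_ids)
    if filtered is not None:                # batches with no surviving requests are dropped
        batches.append(filtered)

batch = FakeBatch.concatenate(batches) if len(batches) > 1 else batches[0]
assert batch.request_ids == [1, 3]
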