OlivierDehaene 2023-04-14 12:33:44 +02:00
parent b6ee0ec7b0
commit 4e63d9cb28
5 changed files with 237 additions and 182 deletions

View File

@@ -132,7 +132,7 @@ message PrefillResponse {
     /// Generation
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    Batch batch = 2;
 }

 message DecodeRequest {
@@ -144,5 +144,5 @@ message DecodeResponse {
     /// Decodes
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    Batch batch = 2;
 }

View File

@@ -7,9 +7,8 @@ use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use text_generation_client::{
-    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
-};
+use flume::SendError;
+use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -339,7 +338,21 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            filter_send_generations(generations, entries);
+            let next_batch = {
+                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
+                batch.requests = batch.requests.into_iter().filter(|r| entries.contains_key(&r.id)).collect();
+                let size = batch.requests.len();
+                if size == 0 {
+                    let _ = client.clear_cache(Some(batch.id)).await;
+                    return None;
+                }
+                batch.size = size as u32;
+                Some(batch)
+            };
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch
@@ -361,17 +374,35 @@ async fn decode(
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
+    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            filter_send_generations(generations, entries);
+            let next_batch = {
+                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
+                batch.requests = batch.requests.into_iter().filter(|r| entries.contains_key(&r.id)).collect();
+                let size = batch.requests.len();
+                if size == 0 {
+                    let _ = client.clear_cache(Some(batch.id)).await;
+                    return None;
+                }
+                batch.size = size as u32;
+                Some(batch)
+            };
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            for id in batch_ids {
+                let _ = client.clear_cache(Some(id)).await;
+            }
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
             None
@@ -398,64 +429,66 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 }

 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+/// and filter entries
 #[instrument(skip_all)]
-fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
     generations.into_iter().for_each(|generation| {
+        let id = generation.request_id;
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
-            .get(&generation.request_id)
+            .get(&id)
             .expect("ID not found in entries. This is a bug.");

         // Create and enter a span to link this function back to the entry
-        let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();

-        // Send generation back to infer task
-        if let Some(prefill_tokens) = generation.prefill_tokens {
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
-                .unwrap_or(());
-        }
-
-        // Create last Token
-        let token = Token {
-            id: generation.token_id,
-            text: generation.token_text,
-            logprob: generation.token_logprob,
-            special: generation.token_is_special,
-        };
-
-        if let Some(generated_text) = generation.generated_text {
-            // Remove entry as this is the last message
-            // We can `expect` here as the request id should always be in the entries
-            let entry = entries
-                .remove(&generation.request_id)
-                .expect("ID not found in entries. This is a bug.");
-
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::End {
-                    token,
-                    generated_text,
-                    queued: entry.queue_time,
-                    start: entry.batch_time.unwrap(),
-                }))
-                .unwrap_or(());
-        } else {
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::Token(token)))
-                .unwrap_or(());
-        }
+        // If we receive an error from the Flume channel, we need to stop generating for this
+        // request, hence the unwrap_or(true)
+        let stopped = send_generation(generation, entry).unwrap_or(true);
+        if stopped {
+            entries.remove(&id).expect("ID not found in entries. This is a bug.");
+        }
     });
 }

+fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+    let mut stopped = false;
+
+    if let Some(prefill_tokens) = generation.prefill_tokens {
+        // Send message
+        entry.response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+    }
+
+    // Create last Token
+    let token = Token {
+        id: generation.token_id,
+        text: generation.token_text,
+        logprob: generation.token_logprob,
+        special: generation.token_is_special,
+    };
+
+    if let Some(generated_text) = generation.generated_text {
+        // Generation has ended
+        stopped = true;
+        // Send message
+        entry.response_tx
+            .send(Ok(InferStreamResponse::End {
+                token,
+                generated_text,
+                queued: entry.queue_time,
+                start: entry.batch_time.unwrap(),
+            }))?;
+    } else {
+        // Send message
+        entry.response_tx
+            .send(Ok(InferStreamResponse::Token(token)))?;
+    }
+
+    Ok(stopped)
+}
+
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
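Taken together, the router change above does two things after every prefill and decode call: `filter_send_generations` streams results back and drops an entry as soon as its request has finished (or its Flume channel is gone), and the cached batch returned by the shards is pruned so it only references requests that still have a live entry, clearing the shard cache when nothing is left. As a rough illustration of that pruning rule only (the real implementation is the Rust code above), here is a small Python sketch; `batch`, `entries`, and `clear_cache` are hypothetical stand-ins, not the router's types:

# Illustrative sketch (not the router code): prune a cached batch to the requests
# that still have a live entry; drop it entirely and clear the shard cache if none remain.
def prune_cached_batch(batch, entries, clear_cache):
    if batch is None:
        return None
    batch.requests = [r for r in batch.requests if r.id in entries]
    if len(batch.requests) == 0:
        clear_cache(batch.id)  # nothing left to decode for this batch
        return None
    batch.size = len(batch.requests)
    return batch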

View File

@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type
+from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
 class CausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]

     # Decoder values
     input_ids: torch.Tensor
@@ -42,7 +43,6 @@ class CausalLMBatch(Batch):
     stopping_criterias: List[StoppingCriteria]

     # Metadata used for padding
-    size: int
     max_input_length: int
     padding_right_offset: int
@@ -53,26 +53,28 @@ class CausalLMBatch(Batch):
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
-            size=self.size,
+            size=len(self),
         )

     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}

         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             offsets.append(None)
             token_offsets.append(None)
@@ -108,26 +110,88 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            all_input_ids=all_input_ids,
+            all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=pb.size,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )

+    @tracer.start_as_current_span("filter")
+    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(requests) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        input_lengths = []
+        offsets = []
+        token_offsets = []
+        all_input_ids = []
+        max_input_length = 0
+
+        for i, r in enumerate(requests):
+            idx = self.requests_idx_mapping[r.id]
+            keep_indices.append(idx)
+
+            requests_idx_mapping[r.id] = i
+
+            offsets.append(self.offsets[idx])
+            token_offsets.append(self.token_offsets[idx])
+            all_input_ids.append(self.all_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(
+                max_input_length, request_input_length
+            )
+
+        # Replace metadata
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_lengths = input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.all_input_ids = all_input_ids
+        self.max_input_length = max_input_length
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        self.input_ids = self.input_ids[keep_indices]
+        self.attention_mask = self.attention_mask[keep_indices]
+        self.position_ids = self.position_ids[keep_indices]
+        # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+        self.past_key_values = [
+            [
+                t.view(len(self), -1, *t.shape[-2:])[keep_indices]
+                for t in layer
+            ]
+            for layer in self.past_key_values
+        ]
+
+        self.requests = [self.requests[i] for i in keep_indices]
+        self.next_token_choosers = [
+            self.next_token_choosers[i] for i in keep_indices
+        ]
+        self.stopping_criterias = [
+            self.stopping_criterias[i] for i in keep_indices
+        ]
+
+        return self
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
@@ -136,12 +200,13 @@ class CausalLMBatch(Batch):
         max_input_length = 0
         padding_right_offset = 0
         for batch in batches:
-            total_batch_size += batch.size
+            total_batch_size += len(batch)
             max_input_length = max(max_input_length, batch.max_input_length)
             padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes
         requests = []
+        requests_idx_mapping = {}
         input_lengths = []
         offsets = []
         token_offsets = []
@@ -167,8 +232,14 @@ class CausalLMBatch(Batch):
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
             # Slicing end index for this batch
-            end_index = start_index + batch.size
+            end_index = start_index + len(batch)

             # We only concatenate batches that did at least one step
             if batch.past_key_values is None:
@@ -192,17 +263,17 @@
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length
             batch_left_offset = (
                 batch.attention_mask.shape[1]
                 - batch.max_input_length
                 - batch.padding_right_offset
             )
             attention_mask[
                 start_index:end_index,
                 left_offset:-padding_right_offset,
             ] = batch.attention_mask[
                 :,
-                batch_left_offset : -batch.padding_right_offset,
+                batch_left_offset: -batch.padding_right_offset,
             ]

             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
@@ -216,8 +287,8 @@
                 # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                 # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
                 # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
-                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
-                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
+                past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
+                past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])

                 _, num_heads, padded_sequence_length, head_dim = past_values.shape
@@ -248,28 +319,29 @@
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_input_length - 1) :,
+                        -(batch.max_input_length - 1):,
                         :,
-                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
+                    ] = past_keys[:, :, -(batch.max_input_length - 1):, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_input_length - 1) :,
-                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
+                        -(batch.max_input_length - 1):,
+                    ] = past_keys[:, :, :, -(batch.max_input_length - 1):]

                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
+                    start_index:end_index, :, -(batch.max_input_length - 1):, :
+                ] = past_values[:, :, -(batch.max_input_length - 1):, :]

-            start_index += batch.size
+            start_index += len(batch)

         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -280,7 +352,6 @@
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=total_batch_size,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -292,11 +363,11 @@
 class CausalLM(Model):
     def __init__(
         self,
         model_id: str,
         revision: Optional[str] = None,
         quantize: bool = False,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -338,7 +409,7 @@
         )

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -352,8 +423,8 @@
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+    ) -> Tuple[List[Generation], CausalLMBatch]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
@@ -364,19 +435,8 @@
             batch.past_key_values,
         )

-        # List of indices to cache
-        next_batch_keep_indices = []
-
         # New values for next forward
-        next_batch_input_lengths = []
-        next_batch_offsets = []
-        next_batch_token_offsets = []
         next_batch_input_ids = []
-        next_batch_all_input_ids = []
-
-        # Metadata
-        next_batch_size = 0
-        next_batch_max_input_length = 0

         # Results
         generations: List[Generation] = []
@@ -395,14 +455,14 @@
         # For each member of the batch
         for i, (
             request,
             input_length,
             offset,
             token_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -429,7 +489,7 @@
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    all_input_ids[-stopping_criteria.current_tokens:, 0]
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -443,16 +503,6 @@
             else:
                 # Keep request in the batch
                 generated_text = None
-                next_batch_keep_indices.append(i)
-                next_batch_input_ids.append(next_token_id)
-                next_batch_all_input_ids.append(all_input_ids)
-                next_batch_size += 1
-                next_batch_input_lengths.append(new_input_length)
-                next_batch_offsets.append(offset)
-                next_batch_token_offsets.append(token_offset)
-                next_batch_max_input_length = max(
-                    next_batch_max_input_length, new_input_length
-                )

             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -484,62 +534,25 @@
             generations.append(generation)

-        # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
-            return generations, None
-
-        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
-        # If we finished at least one generation, we need to evict the indices of the generations that finished
-        # from the values of the next batch
-        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
-            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
-            next_batch_past_key_values = [
-                [
-                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
-                    for t in layer
-                ]
-                for layer in past
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_attention_mask = batch.attention_mask
-            next_batch_position_ids = batch.position_ids
-            next_batch_past_key_values = past
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
+            # Update values
+            next_batch_input_ids.append(next_token_id)
+            batch.all_input_ids[i] = all_input_ids
+            batch.input_lengths[i] = new_input_length
+            batch.offsets[i] = offset
+            batch.token_offsets[i] = token_offset
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
+
+        # Decrease right offset
+        batch.padding_right_offset -= 1
+
+        # Create input_ids tensor
+        batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
         # Update attention_mask as we added a new token to input_ids
-        next_batch_attention_mask[:, -batch.padding_right_offset] = 1
+        batch.attention_mask[:, -batch.padding_right_offset] = 1
         # Update position_ids
-        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
-
-        next_batch = CausalLMBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            input_ids=next_batch_input_ids,
-            attention_mask=next_batch_attention_mask,
-            position_ids=next_batch_position_ids,
-            past_key_values=next_batch_past_key_values,
-            all_input_ids=next_batch_all_input_ids,
-            input_lengths=next_batch_input_lengths,
-            offsets=next_batch_offsets,
-            token_offsets=next_batch_token_offsets,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_input_length=next_batch_max_input_length,
-            padding_right_offset=batch.padding_right_offset - 1,
-            keys_head_dim_last=batch.keys_head_dim_last,
-        )
-        return generations, next_batch
+        batch.position_ids = batch.position_ids[:, -1:] + 1
+        # Update past key values
+        batch.past_key_values = past
+
+        return generations, batch
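The heart of the Python change is the new `requests_idx_mapping` dictionary and `CausalLMBatch.filter`: the mapping turns a request id into its row in the batch tensors, so when the router asks to keep only a subset of requests the batch can index-select its tensors and per-request lists in place, and `generate_token` no longer rebuilds a fresh batch every step. A toy sketch of just that bookkeeping, with stand-in types rather than the server's own classes:

# Toy illustration of the id -> row bookkeeping used by CausalLMBatch.filter above.
# ToyBatch and its tensors are hypothetical stand-ins; only the indexing logic matters.
import torch

class ToyBatch:
    def __init__(self, request_ids, input_ids):
        self.requests = list(request_ids)
        self.requests_idx_mapping = {rid: i for i, rid in enumerate(request_ids)}
        self.input_ids = input_ids  # shape [batch_size, 1]

    def filter(self, keep_ids):
        # Map the surviving request ids to their current rows, then index-select.
        keep_indices = [self.requests_idx_mapping[rid] for rid in keep_ids]
        self.requests = list(keep_ids)
        self.requests_idx_mapping = {rid: i for i, rid in enumerate(keep_ids)}
        self.input_ids = self.input_ids[keep_indices]
        return self

batch = ToyBatch([11, 42, 7], torch.tensor([[10], [20], [30]]))
batch.filter([42, 7])
assert batch.requests_idx_mapping == {42: 0, 7: 1}
assert batch.input_ids.tolist() == [[20], [30]]

Because the surviving rows always correspond exactly to `batch.requests`, the separate `size` field becomes redundant, which is why the diff replaces it with `len(batch)` throughout.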

View File

@@ -25,6 +25,10 @@ class Batch(ABC):
     ) -> "Batch":
         raise NotImplementedError

+    @abstractmethod
+    def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
+        raise NotImplementedError
+
     @classmethod
     @abstractmethod
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
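Declaring `filter` as an `@abstractmethod` on the base `Batch` makes the new operation part of the contract for every model-specific batch type, so a missing implementation fails at instantiation instead of deep inside `Decode`. A minimal sketch of that effect; `IncompleteBatch` is a hypothetical class, not one from the repository:

from abc import ABC, abstractmethod
from typing import List

class Batch(ABC):
    # Abridged version of the abstract base above.
    @abstractmethod
    def filter(self, requests: List[object]) -> "Batch":
        raise NotImplementedError

class IncompleteBatch(Batch):
    pass

try:
    IncompleteBatch()
except TypeError as err:
    # e.g. "Can't instantiate abstract class IncompleteBatch with abstract method filter"
    print(err)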

View File

@@ -60,7 +60,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = self.cache.pop(batch_pb.id)
             if batch is None:
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-            batches.append(batch)
+            batch = batch.filter(batch_pb.requests)
+            if batch is not None:
+                batches.append(batch)
+
+        if len(batches) == 0:
+            raise ValueError("All batches are empty")

         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
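After this change the Decode path is: pop each cached batch, shrink it to the requests the router still wants, error out if nothing is left, then concatenate and generate. A schematic sketch of that flow, where `cache` and `model` are hypothetical stand-ins for the service's members:

def decode_step(cache, model, batch_pbs):
    # batch_pbs: the cached-batch descriptors sent by the router, each carrying
    # the request ids it still wants generated.
    batches = []
    for batch_pb in batch_pbs:
        batch = cache.pop(batch_pb.id)
        if batch is None:
            raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
        batch = batch.filter(batch_pb.requests)
        if batch is not None:
            batches.append(batch)

    if len(batches) == 0:
        raise ValueError("All batches are empty")

    batch = model.batch_type.concatenate(batches) if len(batches) > 1 else batches[0]
    generations, next_batch = model.generate_token(batch)
    cache.set(next_batch)
    return generations, next_batch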