commit d649cd8e02
parent ae466a8736
OlivierDehaene 2023-06-28 19:26:26 +02:00
14 changed files with 469 additions and 394 deletions

View File

@ -115,12 +115,6 @@ struct Args {
#[clap(default_value = "1512", long, env)] #[clap(default_value = "1512", long, env)]
max_total_tokens: usize, max_total_tokens: usize,
/// The maximum allowed batch size during dynamic batching.
/// Using `max_batch_total_tokens` should be favored in general
/// as it's a finer way to control RAM usage.
#[clap(long, env)]
max_batch_size: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where /// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting /// you want to start considering pausing the running queries to include the waiting
/// ones into the same batch. /// ones into the same batch.
@ -134,6 +128,9 @@ struct Args {
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32, waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)]
max_batch_prefill_tokens: u32,
/// **IMPORTANT** This is one critical control to allow maximum usage /// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware. /// of the available hardware.
/// ///
@ -181,7 +178,6 @@ struct Args {
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
max_waiting_tokens: usize, max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)] #[clap(default_value = "3000", long, short, env)]
/// The port to listen on. /// The port to listen on.
port: u16, port: u16,
@ -329,6 +325,12 @@ fn shard_manager(
// Copy current process env // Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
env.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars // Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into())); env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into())); env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -822,6 +824,10 @@ fn spawn_webserver(
args.max_input_length.to_string(), args.max_input_length.to_string(),
"--max-total-tokens".to_string(), "--max-total-tokens".to_string(),
args.max_total_tokens.to_string(), args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(), "--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(), args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
@ -834,15 +840,6 @@ fn spawn_webserver(
args.model_id, args.model_id,
]; ];
// Deprecate max_batch_size
if let Some(max_batch_size) = args.max_batch_size {
argv.push("--max-batch-size".to_string());
argv.push(max_batch_size.to_string())
} else {
argv.push("--max-batch-total-tokens".to_string());
argv.push(args.max_batch_total_tokens.to_string())
}
// Model optional revision // Model optional revision
if let Some(ref revision) = args.revision { if let Some(ref revision) = args.revision {
argv.push("--revision".to_string()); argv.push("--revision".to_string());

View File

@ -45,6 +45,7 @@ impl Infer {
client: ShardedClient, client: ShardedClient,
validation: Validation, validation: Validation,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_concurrent_requests: usize, max_concurrent_requests: usize,
@ -61,6 +62,7 @@ impl Infer {
tokio::spawn(batching_task( tokio::spawn(batching_task(
client, client,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
queue.clone(), queue.clone(),
@ -243,6 +245,7 @@ impl Infer {
async fn batching_task( async fn batching_task(
mut client: ShardedClient, mut client: ShardedClient,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
queue: Queue, queue: Queue,
@ -257,8 +260,9 @@ async fn batching_task(
// Get the next batch from the queue // Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests // This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue // waiting in the queue
while let Some((mut entries, batch, span)) = while let Some((mut entries, batch, span)) = queue
queue.next_batch(None, max_batch_total_tokens).await .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
.await
{ {
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span) .instrument(span)
@ -287,8 +291,9 @@ async fn batching_task(
let token_budget = max_batch_total_tokens - batch_max_tokens; let token_budget = max_batch_total_tokens - batch_max_tokens;
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = if let Some((mut new_entries, new_batch, span)) = queue
queue.next_batch(min_size, token_budget).await .next_batch(min_size, max_batch_prefill_tokens, token_budget)
.await
{ {
// Tracking metrics // Tracking metrics
if min_size.is_some() { if min_size.is_some() {

View File

@ -32,11 +32,11 @@ struct Args {
max_input_length: usize, max_input_length: usize,
#[clap(default_value = "1512", long, env)] #[clap(default_value = "1512", long, env)]
max_total_tokens: usize, max_total_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32, waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)] #[clap(default_value = "32000", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
max_waiting_tokens: usize, max_waiting_tokens: usize,
@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> {
max_stop_sequences, max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
max_batch_size,
waiting_served_ratio, waiting_served_ratio,
mut max_batch_total_tokens, max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
port, port,
master_shard_uds_path, master_shard_uds_path,
@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async { .block_on(async {
init_logging(otlp_endpoint, json_output); init_logging(otlp_endpoint, json_output);
if let Some(max_batch_size) = max_batch_size {
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
}
if tokenizer.is_none() { if tokenizer.is_none() {
tracing::warn!( tracing::warn!(
"Could not find a fast tokenizer implementation for {tokenizer_name}" "Could not find a fast tokenizer implementation for {tokenizer_name}"
@ -161,9 +155,15 @@ fn main() -> Result<(), std::io::Error> {
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}, },
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| { false => get_model_info(&tokenizer_name, &revision, authorization_token)
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub."); tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None } HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
}
}), }),
}; };
@ -206,6 +206,7 @@ fn main() -> Result<(), std::io::Error> {
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
sharded_client, sharded_client,

View File

@ -58,6 +58,7 @@ impl Queue {
pub(crate) async fn next_batch( pub(crate) async fn next_batch(
&self, &self,
min_size: Option<usize>, min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32, token_budget: u32,
) -> Option<NextBatch> { ) -> Option<NextBatch> {
// Create response channel // Create response channel
@ -67,6 +68,7 @@ impl Queue {
self.queue_sender self.queue_sender
.send(QueueCommand::NextBatch { .send(QueueCommand::NextBatch {
min_size, min_size,
prefill_token_budget,
token_budget, token_budget,
response_sender, response_sender,
span: Span::current(), span: Span::current(),
@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
} }
QueueCommand::NextBatch { QueueCommand::NextBatch {
min_size, min_size,
prefill_token_budget,
token_budget, token_budget,
response_sender, response_sender,
span, span,
} => span.in_scope(|| { } => span.in_scope(|| {
let next_batch = state.next_batch(min_size, token_budget); let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap(); response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64); metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}), }),
@ -140,7 +143,12 @@ impl State {
} }
// Get the next batch // Get the next batch
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> { fn next_batch(
&mut self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if self.entries.is_empty() { if self.entries.is_empty() {
return None; return None;
} }
@ -184,7 +192,9 @@ impl State {
decode_tokens += entry.request.stopping_parameters.max_new_tokens; decode_tokens += entry.request.stopping_parameters.max_new_tokens;
if (prefill_tokens + decode_tokens) > token_budget { if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
{
// Entry is over budget // Entry is over budget
// Add it back to the front // Add it back to the front
self.entries.push_front((id, entry)); self.entries.push_front((id, entry));
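The check above is the heart of the new admission logic: a candidate entry is rejected, and pushed back to the front of the queue, as soon as it would overflow either the prefill budget or the overall token budget. A rough Python sketch of that greedy selection (entries and field names are simplified stand-ins for the Rust structures, and the `min_size` handling is omitted since it is not shown in this hunk):

```python
# Illustrative sketch of the dual-budget check in State::next_batch; entries and field
# names are simplified stand-ins for the Rust structures.
from collections import deque

def next_batch(entries: deque, prefill_token_budget: int, token_budget: int):
    batch, prefill_tokens, decode_tokens = [], 0, 0
    while entries:
        entry = entries.popleft()
        prefill_tokens += entry["input_length"]
        decode_tokens += entry["max_new_tokens"]
        if prefill_tokens > prefill_token_budget or prefill_tokens + decode_tokens > token_budget:
            # Entry is over budget: add it back to the front and stop growing the batch
            entries.appendleft(entry)
            break
        batch.append(entry)
    return batch

queue = deque({"input_length": 1000, "max_new_tokens": 200} for _ in range(5))
print(len(next_batch(queue, prefill_token_budget=3000, token_budget=4000)))  # 3
```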
@ -259,6 +269,7 @@ enum QueueCommand {
Append(Box<Entry>, Span), Append(Box<Entry>, Span),
NextBatch { NextBatch {
min_size: Option<usize>, min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32, token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>, response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span, span: Span,
@ -294,7 +305,7 @@ mod tests {
watermark: false, watermark: false,
}, },
stopping_parameters: StoppingCriteriaParameters { stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false, ignore_eos_token: true,
max_new_tokens: 1, max_new_tokens: 1,
stop_sequences: vec![], stop_sequences: vec![],
}, },

View File

@ -152,7 +152,7 @@ async fn generate(
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.0.inputs); // tracing::debug!("Input: {}", req.0.inputs);
let compute_characters = req.0.inputs.chars().count(); let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None; let mut add_prompt = None;
@ -286,7 +286,7 @@ async fn generate(
} }
tracing::debug!("Output: {}", output_text); tracing::debug!("Output: {}", output_text);
tracing::info!("Success"); // tracing::info!("Success");
let response = GenerateResponse { let response = GenerateResponse {
generated_text: output_text, generated_text: output_text,
@ -513,6 +513,7 @@ pub async fn run(
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
client: ShardedClient, client: ShardedClient,
@ -581,6 +582,7 @@ pub async fn run(
client, client,
validation, validation,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_concurrent_requests, max_concurrent_requests,

View File

@ -19,10 +19,12 @@ class Cache:
def delete(self, batch_id: int): def delete(self, batch_id: int):
batch = self.pop(batch_id) batch = self.pop(batch_id)
if batch is not None: if batch is not None:
batch.cleanup()
del batch del batch
def clear(self): def clear(self):
self.cache.clear() for k in list(self.cache.keys()):
self.delete(k)
def __len__(self): def __len__(self):
return len(self.cache.keys()) return len(self.cache.keys())
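`clear()` now routes through `delete()` so that `batch.cleanup()` runs before a batch is dropped; for the flash models (see `flash_causal_lm.py` further down) that hook is what returns the batch's blocks to the cache manager. A toy sketch of the pattern, with simplified stand-ins rather than the real server types:

```python
# Simplified stand-ins showing why clear() must call delete(): dropping a batch without
# running cleanup() would leak its cache blocks.
class ToyAllocator:
    def __init__(self):
        self.free_blocks = set(range(8))

    def allocate(self, n):
        blocks = sorted(self.free_blocks)[:n]
        self.free_blocks -= set(blocks)
        return blocks

    def free(self, blocks):
        self.free_blocks |= set(blocks)

class ToyBatch:
    def __init__(self, batch_id, blocks, allocator):
        self.batch_id = batch_id
        self.blocks = blocks
        self.allocator = allocator

    def cleanup(self):
        # Mirror of FlashCausalLMBatch.cleanup(): give the blocks back
        self.allocator.free(self.blocks)

class ToyCache:
    def __init__(self):
        self.cache = {}

    def set(self, batch):
        self.cache[batch.batch_id] = batch

    def delete(self, batch_id):
        batch = self.cache.pop(batch_id, None)
        if batch is not None:
            batch.cleanup()

    def clear(self):
        for k in list(self.cache.keys()):
            self.delete(k)

allocator = ToyAllocator()
cache = ToyCache()
cache.set(ToyBatch(0, allocator.allocate(3), allocator))
cache.clear()
print(allocator.free_blocks)  # all 8 blocks available again
```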

View File

@ -23,7 +23,9 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional from typing import Optional, List, Tuple
from vllm import attention_ops
from vllm import cache_ops
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
@ -106,7 +108,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix=f"{prefix}.rotary_emb", weights=weights prefix=f"{prefix}.rotary_emb", weights=weights
) )
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size**-0.5
self.num_heads = self.num_heads // weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi( self.query_key_value = TensorParallelColumnLinear.load_multi(
@ -128,14 +130,13 @@ class FlashLlamaAttention(torch.nn.Module):
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -144,23 +145,25 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill cache_ops.reshape_and_cache(
if prefill: qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
# Copy to layer past )
layer_past[...] = qkv[:, 1:]
# output # output tensor
attn_output = torch.empty_like(qkv[:, 0]) attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output, attn_output,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -173,31 +176,18 @@ class FlashLlamaAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv[:, 0] # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
# Add present to the layer_past tensor at the correct indices block_size = kv_cache[1].shape[3]
layer_past[past_present_indices] = qkv[:, 1:] attention_ops.single_query_cached_kv_attention(
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
attn_output, attn_output,
start_seq_q, qkv[:, 0],
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
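At decode time the kernel call above reads each sequence's K/V through its block table instead of a contiguous past tensor. The reference below sketches, in plain PyTorch and with a deliberately simplified `[num_blocks, block_size, num_heads, head_size]` cache layout (not the packed layouts the CUDA kernel consumes), what that single-query paged attention computes for one sequence:

```python
# CPU-only reference for single-query attention over a paged KV cache. The cache layout
# here is simplified for readability; it is not the packed layout used by the CUDA kernel.
import torch

def paged_decode_attention(query, key_cache, value_cache, block_table, seq_len, softmax_scale):
    # query: [num_heads, head_size] for the single new token of one sequence
    # key_cache / value_cache: [num_blocks, block_size, num_heads, head_size]
    # block_table: physical block indices owned by this sequence, in logical order
    keys = key_cache[block_table].flatten(0, 1)[:seq_len]      # [seq_len, num_heads, head_size]
    values = value_cache[block_table].flatten(0, 1)[:seq_len]  # [seq_len, num_heads, head_size]
    scores = torch.einsum("hd,shd->hs", query, keys) * softmax_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,shd->hd", probs, values)           # [num_heads, head_size]

num_blocks, block_size, num_heads, head_size = 8, 16, 2, 4
k = torch.randn(num_blocks, block_size, num_heads, head_size)
v = torch.randn(num_blocks, block_size, num_heads, head_size)
out = paged_decode_attention(
    torch.randn(num_heads, head_size), k, v, block_table=[3, 5], seq_len=20,
    softmax_scale=head_size ** -0.5,
)
print(out.shape)  # torch.Size([2, 4])
```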
@ -265,14 +255,13 @@ class FlashLlamaLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -281,14 +270,13 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states, normed_hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
# faster post attention rms norm # faster post attention rms norm
@ -333,40 +321,18 @@ class FlashLlamaModel(torch.nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values=None, max_s: int,
pre_allocate_past_size: Optional[int] = None, ) -> torch.Tensor:
):
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@ -380,34 +346,18 @@ class FlashLlamaModel(torch.nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache[i],
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
past_key_values[:, i],
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
@ -423,31 +373,29 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values: Optional[torch.Tensor] = None, max_s: int,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return logits, present return logits

View File

@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
try: try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError: except RuntimeError:
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights) self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
encoder_config = copy.deepcopy(config) encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False encoder_config.is_decoder = False

View File

@ -1,3 +1,4 @@
import itertools
import torch import torch
import torch.distributed import torch.distributed
@ -5,7 +6,7 @@ import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model from text_generation_server.models import Model
@ -20,6 +21,66 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
class CacheManager:
def __init__(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.block_size = 16
element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, self.block_size, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, self.block_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int32
).view(num_blocks, self.block_size)
def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]:
# Number of blocks needed to cover all tokens
needed_blocks = (n_tokens // self.block_size) + 1
# Get free block indices by finding values in the mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero()
assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks"
# Allocate the required number of blocks by setting the mask to 0
block_indices = free_block_indices[:needed_blocks]
self.free_block_mask[block_indices] = 0
# Get slots for the allocated blocks
slots = self.slots[block_indices].flatten()[:n_tokens]
return block_indices.flatten().tolist(), slots
def free(self, block_indices: List[int]):
# Reset mask
self.free_block_mask[block_indices] = 1
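A quick usage sketch of the allocator above, on CPU and with tiny illustrative sizes: 20 tokens with a block size of 16 need two blocks, the returned slots are the first 20 positions of those blocks, and freeing the blocks makes them reusable. It assumes the `CacheManager` class above is in scope:

```python
# Usage sketch for the CacheManager above; sizes are illustrative, not the server defaults.
import torch

manager = CacheManager(
    num_blocks=8, num_layers=1, num_heads=2, head_size=64,
    dtype=torch.float16, device=torch.device("cpu"),
)

blocks, slots = manager.allocate(20)
print(blocks)              # [0, 1] -> two 16-token blocks cover 20 tokens
print(slots.tolist()[:4])  # [0, 1, 2, 3] -> slot ids index the flat slot space
print(len(slots))          # 20

manager.free(blocks)       # blocks 0 and 1 become available for other requests
```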
@dataclass @dataclass
class FlashCausalLMBatch(Batch): class FlashCausalLMBatch(Batch):
@ -32,23 +93,20 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor input_ids: torch.Tensor
position_ids: torch.Tensor position_ids: torch.Tensor
# Indices to copy present to the correct indices is the pre-allocated past key values
past_present_indices: torch.Tensor
# tensor of length b holding starting offset of each sequence
start_seq: torch.Tensor
# tensor of length b holding ending offset of each sequence
end_seq: torch.Tensor
# tensor of length b holding starting offset of each sequence, only used in prefill # tensor of length b holding starting offset of each sequence, only used in prefill
start_seq_prefill: Optional[torch.Tensor] start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill # tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor] end_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding starting offset of each query sequence, only used in decode # list of length b of list of length s_i // block_size
start_seq_q: Optional[torch.Tensor] block_tables: List[List[int]]
# tensor of length b holding ending offset of each query sequence, only used in decode # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
end_seq_q: Optional[torch.Tensor] block_tables_tensor: torch.Tensor
# past key values, only used in decode # CPU tensor of length b indicating the start of each sequence in slots
past_key_values: Optional[torch.Tensor] start_slots: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
max_seqlen: int max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs # Prefill metadata tensors to efficiently compute logprobs
@ -62,6 +120,7 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
input_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]] prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]] read_offsets: List[Optional[int]]
@ -69,15 +128,16 @@ class FlashCausalLMBatch(Batch):
next_token_chooser: HeterogeneousNextTokenChooser next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to # Maximum number of blocks
max_tokens: int max_blocks: int
def to_pb(self) -> generate_pb2.CachedBatch: def to_pb(self) -> generate_pb2.CachedBatch:
global CACHE_MANAGER
return generate_pb2.CachedBatch( return generate_pb2.CachedBatch(
id=self.batch_id, id=self.batch_id,
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=len(self.slots),
) )
@classmethod @classmethod
@ -88,6 +148,8 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
global CACHE_MANAGER
batch_inputs = [] batch_inputs = []
max_truncation = 0 max_truncation = 0
for r in pb.requests: for r in pb.requests:
@ -99,12 +161,12 @@ class FlashCausalLMBatch(Batch):
)["input_ids"] )["input_ids"]
position_ids = [] position_ids = []
past_present_indices = []
start_seq = []
end_seq = []
start_seq_prefill = [] start_seq_prefill = []
end_seq_prefill = [] end_seq_prefill = []
max_seqlen = 0 block_tables = []
start_slots = []
slots = []
slot_indices = []
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
@ -126,7 +188,9 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0 cumulative_max_length = 0
prefill_out_cumulative_length = 0 prefill_out_cumulative_length = 0
max_seqlen = 0
max_length = 0 max_length = 0
max_blocks = 0
# Parse batch # Parse batch
for i, (r, tokenized_input) in enumerate( for i, (r, tokenized_input) in enumerate(
@ -138,7 +202,6 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :] tokenized_input = tokenized_input[-r.truncate :]
input_length = len(tokenized_input) input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length) input_lengths.append(input_length)
prefix_offsets.append(input_length - 5) prefix_offsets.append(input_length - 5)
@ -153,8 +216,6 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs # Add cumulative lengths of all previous inputs
start_seq_prefill.append(cumulative_length) start_seq_prefill.append(cumulative_length)
end_seq_prefill.append(cumulative_length + input_length) end_seq_prefill.append(cumulative_length + input_length)
start_seq.append(cumulative_max_length)
end_seq.append(cumulative_max_length + input_length)
next_token_chooser_parameters.append(r.parameters) next_token_chooser_parameters.append(r.parameters)
@ -164,6 +225,21 @@ class FlashCausalLMBatch(Batch):
max_new_tokens = stopping_criteria.max_new_tokens max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens)
block_tables.append(request_blocks)
slots.extend(request_slots)
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@ -184,22 +260,26 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1 prefill_out_cumulative_length += 1
request_past_present_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
past_present_indices.append(request_past_present_indices)
# Update # Update
# Remove one as the first token does not have a past
cumulative_length += input_length cumulative_length += input_length
cumulative_max_length += input_length + max_new_tokens - 1 cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks))
max_length = max(max_length, input_length + max_new_tokens) max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device next_token_chooser_parameters, dtype, device
) )
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded block tables
block_tables_tensor = torch.zeros(
(len(pb.requests), max_blocks), dtype=torch.int32
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device)
# Padded all_input_ids_tensor # Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros( all_input_ids_tensor = np.zeros(
@ -212,14 +292,15 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = torch.tensor( all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device all_input_ids_tensor, dtype=torch.int64, device=device
) )
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
if len(pb.requests) > 1: if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64) input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids) position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64) else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
start_seq_prefill = torch.tensor( start_seq_prefill = torch.tensor(
start_seq_prefill, device=device, dtype=torch.int32 start_seq_prefill, device=device, dtype=torch.int32
@ -227,19 +308,13 @@ class FlashCausalLMBatch(Batch):
end_seq_prefill = torch.tensor( end_seq_prefill = torch.tensor(
end_seq_prefill, device=device, dtype=torch.int32 end_seq_prefill, device=device, dtype=torch.int32
) )
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
past_present_indices = past_present_indices[0]
start_seq_prefill = start_seq
end_seq_prefill = end_seq
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) slots = torch.tensor(slots, dtype=torch.int32, device=device)
past_present_indices = torch.tensor( input_lengths_tensor = torch.tensor(
past_present_indices, device=device, dtype=torch.int64 input_lengths, dtype=torch.int32, device=device
) )
if all_prefill_logprobs: if all_prefill_logprobs:
@ -262,30 +337,31 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=start_seq_prefill, start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill, end_seq_prefill=end_seq_prefill,
start_seq_q=None, block_tables=block_tables,
end_seq_q=None, block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices, prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices, prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens, prefill_cu_outlens=prefill_cu_outlens,
past_key_values=None,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length, max_blocks=max_blocks,
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
if len(request_ids) == 0: if len(request_ids) == 0:
raise ValueError("Batch must have at least one request") raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same # We assume that if len(requests) == len(self) then the requests are the same
@ -294,28 +370,24 @@ class FlashCausalLMBatch(Batch):
device = self.input_ids.device device = self.input_ids.device
# Cumulative length
cumulative_max_length = 0
# New values after filtering # New values after filtering
requests_idx_mapping = {} requests_idx_mapping = {}
# Used to index into tensors # Used to index into tensors
indices = [] indices = []
# past indices to keep # slots to keep after filtering
past_indices = torch.zeros( slot_filtering_indices = torch.zeros(
self.past_key_values.shape[0], dtype=torch.bool, device=device self.slots.shape[0], dtype=torch.bool, device=device
) )
# Create on CPU to only move to GPU once instead of at every copy # Create on CPU to only move to GPU once instead of at every copy
start_seq = torch.empty(len(request_ids), dtype=torch.int32) slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
start_seq_q = self.start_seq_q[: len(request_ids)]
end_seq_q = self.end_seq_q[: len(request_ids)]
max_seqlen = 0 max_seqlen = 0
requests = [] requests = []
start_slots = []
block_tables = []
all_input_ids = [] all_input_ids = []
input_lengths = [] input_lengths = []
@ -324,6 +396,10 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = [] stopping_criterias = []
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
for i, request_id in enumerate(request_ids): for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id] idx = self.requests_idx_mapping[request_id]
indices.append(idx) indices.append(idx)
@ -348,28 +424,45 @@ class FlashCausalLMBatch(Batch):
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
) )
request_block_table = self.block_tables[idx]
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
# Copy to tensor (CPU) # Copy to tensor (CPU)
start_seq[i] = cumulative_max_length slot_indices[i] = cumulative_max_length + request_input_length - 1
end_seq[i] = cumulative_max_length + request_input_length
# Set slice # Set slice
past_indices[ slot_filtering_indices[
self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1 self.start_slots[idx] : self.start_slots[idx]
+ request_input_length
+ remaining_tokens
- 1
] = True ] = True
cumulative_max_length += request_input_length + remaining_tokens - 1 cumulative_max_length += request_input_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table))
# Iterate on all requests
for i, r in enumerate(self.requests):
# Filter requests that are not part of the new batch
if r.id not in requests_idx_mapping.keys():
# Free blocks
CACHE_MANAGER.free(self.block_tables[i])
# Index into tensors # Index into tensors
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
past_key_values = self.past_key_values[past_indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Move to GPU now that we have the whole tensor # Move to GPU now that we have the whole tensor
start_seq = start_seq.to(device) slot_indices = slot_indices.to(device)
end_seq = end_seq.to(device)
past_present_indices = end_seq - 1
return FlashCausalLMBatch( return FlashCausalLMBatch(
batch_id=self.batch_id, batch_id=self.batch_id,
@ -377,51 +470,74 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None, start_seq_prefill=None,
end_seq_prefill=None, end_seq_prefill=None,
start_seq_q=start_seq_q, block_tables=block_tables,
end_seq_q=end_seq_q, block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length, max_blocks=max_blocks,
) )
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
# Batch attributes # Batch attributes
requests = [] requests = []
requests_idx_mapping = {} requests_idx_mapping = {}
total_batch_size = sum([len(b) for b in batches]) total_batch_size = 0
total_slots = 0
dtype = batches[0].past_key_values.dtype max_blocks = 0
device = batches[0].input_ids.device max_length = 0
max_seqlen = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
)
),
)
input_ids = batches[0].input_ids.new_empty(total_batch_size) input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size)
start_seq = batches[0].start_seq.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots)
end_seq = batches[0].end_seq.new_empty(total_batch_size) slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
start_seq_q = torch.arange( input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
0, total_batch_size, device=device, dtype=torch.int32 total_batch_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
) )
end_seq_q = start_seq_q + 1
max_seqlen = 0
past_key_values = []
start_slots = []
block_tables = []
all_input_ids = [] all_input_ids = []
input_lengths = [] input_lengths = []
@ -433,8 +549,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length # Cumulative length
cumulative_batch_size = 0 cumulative_batch_size = 0
max_tokens = 0 cumulative_slots = 0
max_length = 0
for i, batch in enumerate(batches): for i, batch in enumerate(batches):
requests.extend(batch.requests) requests.extend(batch.requests)
@ -448,16 +563,27 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch) end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU) # Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
slots[slots_start_index:slots_end_index] = batch.slots
start_seq[start_index:end_index] = batch.start_seq + max_tokens all_input_ids_tensor[
end_seq[start_index:end_index] = batch.end_seq + max_tokens start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
max_seqlen = max(max_seqlen, batch.max_seqlen) block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks]
start_slots.append(batch.start_slots + cumulative_slots)
block_tables.extend(batch.block_tables)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
@ -466,43 +592,17 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
past_key_values.append(batch.past_key_values)
# Update # Update
cumulative_batch_size += len(batch) cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens cumulative_slots += len(batch.slots)
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
batch.input_lengths, batch.stopping_criterias
)
),
)
past_key_values = torch.cat(past_key_values, dim=0) start_slots = torch.concat(start_slots)
past_present_indices = end_seq - 1
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
cumulative_batch_size = 0
for i, batch in enumerate(batches):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
cumulative_batch_size += len(batch)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
) )
return FlashCausalLMBatch( return FlashCausalLMBatch(
@ -511,28 +611,33 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None, start_seq_prefill=None,
end_seq_prefill=None, end_seq_prefill=None,
start_seq_q=start_seq_q, block_tables=block_tables,
end_seq_q=end_seq_q, block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_tokens=max_tokens, max_blocks=max_blocks,
) )
def cleanup(self):
global CACHE_MANAGER
# Free blocks
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
@ -540,32 +645,24 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model): class FlashCausalLM(Model):
def __init__( def __init__(
self, self,
model_cls: Type[PreTrainedModel], model: torch.nn.Module,
model_id: str, tokenizer: PreTrainedTokenizerBase,
revision: Optional[str] = None, num_layers: int,
quantize: Optional[str] = None, num_heads: int,
trust_remote_code: bool = False, head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
): ):
if torch.cuda.is_available(): self.num_heads = num_heads
device = torch.device("cuda") self.head_size = head_size
dtype = torch.float16
else:
raise NotImplementedError("FlashCausalLM is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained( global CACHE_MANAGER
model_id, torch.cuda.set_per_process_memory_fraction(1.0)
revision=revision, CACHE_MANAGER = CacheManager(
padding_side="left", 1000, num_layers, num_heads, head_size, dtype, device
truncation_side="left",
trust_remote_code=trust_remote_code,
) )
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
).to(device)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model, model=model,
@ -573,6 +670,8 @@ class FlashCausalLM(Model):
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
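The cache manager is created here with a hard-coded 1000 blocks per layer. A back-of-the-envelope sketch of what that fixed allocation costs in memory, assuming for illustration a LLaMA-7B-like single-shard shape, which is not stated in the diff:

```python
# Rough memory footprint of the fixed-size paged cache created above. The 1000 blocks and
# block_size of 16 come from the code; the model shape below is an illustrative assumption.
num_blocks, block_size = 1000, 16
num_layers, num_heads, head_size = 32, 32, 128
bytes_per_element = 2  # fp16

# One key tensor + one value tensor per layer, each holding num_blocks * block_size slots
kv_bytes = 2 * num_layers * num_blocks * block_size * num_heads * head_size * bytes_per_element
print(f"{kv_bytes / 2**30:.2f} GiB for {num_blocks * block_size} cached token positions")
# -> 7.81 GiB for 16000 cached token positions
```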
@property @property
@ -588,28 +687,27 @@ class FlashCausalLM(Model):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
start_seq: torch.Tensor, start_seq_prefill: Optional[torch.Tensor],
end_seq: torch.Tensor, end_seq_prefill: Optional[torch.Tensor],
start_seq_q: Optional[torch.Tensor], block_tables: torch.Tensor,
end_seq_q: Optional[torch.Tensor], slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int, max_s: int,
past_present_indices: torch.Tensor,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
global CACHE_MANAGER
# Model Forward # Model Forward
return self.model.forward( return self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
start_seq=start_seq, start_seq_prefill=start_seq_prefill,
end_seq=end_seq, end_seq_prefill=end_seq_prefill,
start_seq_q=start_seq_q, kv_cache=CACHE_MANAGER.kv_cache,
end_seq_q=end_seq_q, block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s, max_s=max_s,
past_present_indices=past_present_indices,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
) )
@ -617,31 +715,18 @@ class FlashCausalLM(Model):
def generate_token( def generate_token(
self, batch: FlashCausalLMBatch self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None prefill = batch.start_seq_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
if prefill: out = self.forward(
# Ask to pre-allocate kv to its max size
# == Sum over batch size (number of tokens + max_new_tokens) - batch size
pre_allocate_past_size = batch.max_tokens
start_seq = batch.start_seq_prefill
end_seq = batch.end_seq_prefill
else:
pre_allocate_past_size = None
start_seq = batch.start_seq
end_seq = batch.end_seq
out, present = self.forward(
batch.input_ids, batch.input_ids,
batch.position_ids, batch.position_ids,
start_seq, batch.start_seq_prefill,
end_seq, batch.end_seq_prefill,
batch.start_seq_q, batch.block_tables_tensor,
batch.end_seq_q, batch.slots[batch.slot_indices],
batch.input_lengths_tensor,
batch.max_seqlen, batch.max_seqlen,
batch.past_present_indices,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices, batch.prefill_head_indices,
) )
@ -662,12 +747,8 @@ class FlashCausalLM(Model):
# When batch == 1, we will just use the batch.input_ids values directly # When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.start_seq_q and batch.end_seq_q for decode
batch.start_seq_q = torch.arange(
0, len(batch), device=self.device, dtype=torch.int32
)
batch.end_seq_q = batch.start_seq_q + 1
next_position_ids = batch.position_ids.new_empty(len(batch)) next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
# We do not need start_seq_prefill and end_seq_prefill anymore # We do not need start_seq_prefill and end_seq_prefill anymore
batch.start_seq_prefill = None batch.start_seq_prefill = None
batch.end_seq_prefill = None batch.end_seq_prefill = None
@ -731,8 +812,8 @@ class FlashCausalLM(Model):
# Set values in batch # Set values in batch
batch.input_ids = next_input_ids batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1 batch.position_ids = next_position_ids + 1
batch.past_present_indices = batch.end_seq batch.input_lengths_tensor += 1
batch.end_seq = batch.end_seq + 1 batch.slot_indices += 1
if prefill and prefill_logprobs: if prefill and prefill_logprobs:
# Get prefill logprobs # Get prefill logprobs
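Two lines in this and the previous hunk carry the slot bookkeeping across the prefill/decode boundary: after prefill, `slot_indices` shrinks from one index per prompt token to one index per sequence (the slot of its last prefill token), and every decode step then advances each sequence's index by one so the new token's K/V lands in the next slot. A small walk-through with two illustrative prompts:

```python
# Illustrative walk-through of slot_indices for a batch of two prompts (lengths 3 and 5),
# mirroring the per-request arange built in from_pb and the updates in generate_token.
import torch

input_lengths = [3, 5]
max_new_tokens = [4, 2]

slot_indices, end_seq_prefill = [], []
cumulative_slots, cumulative_length = 0, 0
for length, new_tokens in zip(input_lengths, max_new_tokens):
    # One slot index per prompt token, offset into this request's slot range
    slot_indices.append(torch.arange(cumulative_slots, cumulative_slots + length))
    cumulative_length += length
    end_seq_prefill.append(cumulative_length)
    # Each request owns input_length + max_new_tokens - 1 slots
    cumulative_slots += length + new_tokens - 1

slot_indices = torch.cat(slot_indices)             # tensor([ 0,  1,  2,  6,  7,  8,  9, 10])
end_seq_prefill = torch.tensor(end_seq_prefill)    # tensor([3, 8])

# After prefill: keep only the slot of the last prompt token of each sequence
slot_indices = slot_indices[end_seq_prefill - 1]   # tensor([ 2, 10])
# Each decode step writes the new token's KV into the next slot
slot_indices += 1                                  # tensor([ 3, 11])
```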
@ -755,7 +836,6 @@ class FlashCausalLM(Model):
batch.read_offsets, batch.read_offsets,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.all_input_ids_tensor,
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
next_token_ids, next_token_ids,
@ -770,7 +850,6 @@ class FlashCausalLM(Model):
read_offset, read_offset,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
all_input_ids_tensor,
do_sample, do_sample,
seed, seed,
next_token_id, next_token_id,
@ -845,19 +924,20 @@ class FlashCausalLM(Model):
generations.append(generation) generations.append(generation)
new_input_length = input_length + 1
# Update values # Update values
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = input_length + 1
batch.prefix_offsets[i] = prefix_offset batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
if stopped:
batch.cleanup()
# No need to return a batch if we know that all requests stopped
return generations, None
batch.prefill_cu_outlens = None batch.prefill_cu_outlens = None
batch.prefill_head_indices = None batch.prefill_head_indices = None
batch.prefill_next_token_indices = None batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1 batch.max_seqlen = batch.max_seqlen + 1
batch.past_key_values = present
# No need to return a batch if we know that all requests stopped return generations, batch
return generations, batch if not stopped else None

View File

@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config, weights) model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashLlama, self).__init__(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.model.layers),
num_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,

View File

@ -52,8 +52,11 @@ class FlashSantacoderSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group, filenames,
aliases = {"transformer.wte.weight": ["lm_head.weight"]} device=device,
dtype=dtype,
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
) )
model = FlashSantacoderForCausalLM(config, weights) model = FlashSantacoderForCausalLM(config, weights)

View File

@ -35,6 +35,9 @@ class Batch(ABC):
def concatenate(cls, batches: List["Batch"]) -> "Batch": def concatenate(cls, batches: List["Batch"]) -> "Batch":
raise NotImplementedError raise NotImplementedError
def cleanup(self):
pass
@abstractmethod @abstractmethod
def __len__(self): def __len__(self):
raise NotImplementedError raise NotImplementedError

View File

@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser:
self.seeds = seeds self.seeds = seeds
self.do_sample = do_sample self.do_sample = do_sample
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor is not None: if self.watermark_processor is not None:

View File

@ -5,7 +5,14 @@ import torch
class Weights: class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): def __init__(
self,
filenames: List[Path],
device,
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
):
routing = {} routing = {}
for filename in filenames: for filename in filenames:
with safe_open(filename, framework="pytorch") as f: with safe_open(filename, framework="pytorch") as f:
@ -43,7 +50,7 @@ class Weights:
return str(filename), tensor_name return str(filename), tensor_name
def _get_slice(self, tensor_name: str): def _get_slice(self, tensor_name: str):
filename, tensor_name= self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name) slice_ = f.get_slice(tensor_name)
return slice_ return slice_
@ -94,12 +101,20 @@ class Weights:
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "gptq": if quantize == "gptq":
try: try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError: except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) qzeros = torch.cat(
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]: for w2 in w[1:]:
torch.testing.assert_close(w2, w[0]) torch.testing.assert_close(w2, w[0])
@ -118,7 +133,9 @@ class Weights:
try: try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0) qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError: except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = self.get_tensor(f"{prefix}.qzeros") qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales") scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)