Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)

commit d649cd8e02 ("wip"), parent ae466a8736
@@ -115,12 +115,6 @@ struct Args {
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,

-    /// The maximum allowed batch size during dynamic batching.
-    /// Using `max_batch_total_tokens` should be favored in general
-    /// as it's a finer way to control RAM usage.
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
-
     /// This represents the ratio of waiting queries vs running queries where
     /// you want to start considering pausing the running queries to include the waiting
     /// ones into the same batch.
@@ -134,6 +128,9 @@ struct Args {
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,

+    #[clap(default_value = "32000", long, env)]
+    max_batch_prefill_tokens: u32,
+
     /// **IMPORTANT** This is one critical control to allow maximum usage
     /// of the available hardware.
     ///
@@ -181,7 +178,6 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
-
     /// The port to listen on.
     port: u16,

@@ -329,6 +325,12 @@ fn shard_manager(
     // Copy current process env
     let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

+    // Use cuda allocator. It leads to less memory fragmentation
+    env.push((
+        "PYTORCH_CUDA_ALLOC_CONF".into(),
+        "backend:cudaMallocAsync".into(),
+    ));
+
     // Torch Distributed Env vars
     env.push(("RANK".into(), rank.to_string().into()));
     env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@@ -822,6 +824,10 @@ fn spawn_webserver(
         args.max_input_length.to_string(),
         "--max-total-tokens".to_string(),
         args.max_total_tokens.to_string(),
+        "--max-batch-prefill-tokens".to_string(),
+        args.max_batch_prefill_tokens.to_string(),
+        "--max-batch-total-tokens".to_string(),
+        args.max_batch_total_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
@@ -834,15 +840,6 @@ fn spawn_webserver(
         args.model_id,
     ];

-    // Deprecate max_batch_size
-    if let Some(max_batch_size) = args.max_batch_size {
-        argv.push("--max-batch-size".to_string());
-        argv.push(max_batch_size.to_string())
-    } else {
-        argv.push("--max-batch-total-tokens".to_string());
-        argv.push(args.max_batch_total_tokens.to_string())
-    }
-
     // Model optional revision
     if let Some(ref revision) = args.revision {
         argv.push("--revision".to_string());
@@ -45,6 +45,7 @@ impl Infer {
         client: ShardedClient,
         validation: Validation,
         waiting_served_ratio: f32,
+        max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
@@ -61,6 +62,7 @@ impl Infer {
         tokio::spawn(batching_task(
             client,
             waiting_served_ratio,
+            max_batch_prefill_tokens,
             max_batch_total_tokens,
             max_waiting_tokens,
             queue.clone(),
@@ -243,6 +245,7 @@ impl Infer {
 async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
+    max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     queue: Queue,
@@ -257,8 +260,9 @@ async fn batching_task(
     // Get the next batch from the queue
     // This batch might be smaller than the maximum batch size if there are not enough requests
     // waiting in the queue
-    while let Some((mut entries, batch, span)) =
-        queue.next_batch(None, max_batch_total_tokens).await
+    while let Some((mut entries, batch, span)) = queue
+        .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+        .await
     {
         let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
             .instrument(span)
@@ -287,8 +291,9 @@ async fn batching_task(
                 let token_budget = max_batch_total_tokens - batch_max_tokens;

                 // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) =
-                    queue.next_batch(min_size, token_budget).await
+                if let Some((mut new_entries, new_batch, span)) = queue
+                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .await
                 {
                     // Tracking metrics
                     if min_size.is_some() {
@@ -32,11 +32,11 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
     #[clap(default_value = "32000", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(default_value = "32000", long, env)]
     max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
@@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> {
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
-        max_batch_size,
         waiting_served_ratio,
-        mut max_batch_total_tokens,
+        max_batch_prefill_tokens,
+        max_batch_total_tokens,
         max_waiting_tokens,
         port,
         master_shard_uds_path,
@@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);

-            if let Some(max_batch_size) = max_batch_size {
-                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
-                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
-                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
-            }
-
             if tokenizer.is_none() {
                 tracing::warn!(
                     "Could not find a fast tokenizer implementation for {tokenizer_name}"
@@ -161,10 +155,16 @@ fn main() -> Result<(), std::io::Error> {
                     sha: None,
                     pipeline_tag: None,
                 },
-                false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
-                    tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-                    HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
-                }),
+                false => get_model_info(&tokenizer_name, &revision, authorization_token)
+                    .await
+                    .unwrap_or_else(|| {
+                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                        HubModelInfo {
+                            model_id: tokenizer_name.to_string(),
+                            sha: None,
+                            pipeline_tag: None,
+                        }
+                    }),
             };

             // if pipeline-tag == text-generation we default to return_full_text = true
@@ -206,6 +206,7 @@ fn main() -> Result<(), std::io::Error> {
                 max_input_length,
                 max_total_tokens,
                 waiting_served_ratio,
+                max_batch_prefill_tokens,
                 max_batch_total_tokens,
                 max_waiting_tokens,
                 sharded_client,
@@ -219,7 +220,7 @@ fn main() -> Result<(), std::io::Error> {
                 ngrok_username,
                 ngrok_password,
             )
             .await;
             Ok(())
         })
 }
@@ -58,6 +58,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
+        prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
         // Create response channel
@@ -67,6 +68,7 @@ impl Queue {
         self.queue_sender
             .send(QueueCommand::NextBatch {
                 min_size,
+                prefill_token_budget,
                 token_budget,
                 response_sender,
                 span: Span::current(),
@@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand
             }
             QueueCommand::NextBatch {
                 min_size,
+                prefill_token_budget,
                 token_budget,
                 response_sender,
                 span,
             } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, token_budget);
+                let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
@@ -140,7 +143,12 @@ impl State {
     }

     // Get the next batch
-    fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
+    fn next_batch(
+        &mut self,
+        min_size: Option<usize>,
+        prefill_token_budget: u32,
+        token_budget: u32,
+    ) -> Option<NextBatch> {
         if self.entries.is_empty() {
             return None;
         }
@@ -184,7 +192,9 @@ impl State {

             decode_tokens += entry.request.stopping_parameters.max_new_tokens;

-            if (prefill_tokens + decode_tokens) > token_budget {
+            if prefill_tokens > prefill_token_budget
+                || (prefill_tokens + decode_tokens) > token_budget
+            {
                 // Entry is over budget
                 // Add it back to the front
                 self.entries.push_front((id, entry));
@@ -259,6 +269,7 @@ enum QueueCommand {
     Append(Box<Entry>, Span),
     NextBatch {
         min_size: Option<usize>,
+        prefill_token_budget: u32,
         token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
         span: Span,
@@ -294,7 +305,7 @@ mod tests {
                 watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
-                ignore_eos_token: false,
+                ignore_eos_token: true,
                 max_new_tokens: 1,
                 stop_sequences: vec![],
             },
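The queue change above is the heart of the scheduling tweak: a waiting request is now skipped if either the accumulated prefill tokens would exceed the new prefill_token_budget or the combined prefill-plus-decode tokens would exceed the overall token_budget. A minimal Python sketch of that admission loop follows; it only illustrates the budget check shown in the Rust State::next_batch above, and the (input_length, max_new_tokens) tuples and function name are invented for the example.

from collections import deque

def next_batch(entries, prefill_token_budget, token_budget):
    """Greedy admission mirroring the dual budget check added to State::next_batch."""
    batch = []
    prefill_tokens = 0
    decode_tokens = 0
    while entries:
        input_length, max_new_tokens = entries.popleft()
        prefill_tokens += input_length
        decode_tokens += max_new_tokens
        if prefill_tokens > prefill_token_budget or (prefill_tokens + decode_tokens) > token_budget:
            # Entry is over budget: put it back at the front of the queue and stop
            entries.appendleft((input_length, max_new_tokens))
            break
        batch.append((input_length, max_new_tokens))
    return batch

# Example: two prompts of 500 and 400 tokens fit a 1000-token prefill budget,
# but a third 200-token prompt would push prefill to 1100 and stays queued.
queue = deque([(500, 64), (400, 64), (200, 64)])
print(next_batch(queue, prefill_token_budget=1000, token_budget=4096))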
@@ -152,7 +152,7 @@ async fn generate(
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");

-    tracing::debug!("Input: {}", req.0.inputs);
+    // tracing::debug!("Input: {}", req.0.inputs);

     let compute_characters = req.0.inputs.chars().count();
     let mut add_prompt = None;
@@ -286,7 +286,7 @@ async fn generate(
     }

     tracing::debug!("Output: {}", output_text);
-    tracing::info!("Success");
+    // tracing::info!("Success");

     let response = GenerateResponse {
         generated_text: output_text,
@@ -513,6 +513,7 @@ pub async fn run(
     max_input_length: usize,
     max_total_tokens: usize,
     waiting_served_ratio: f32,
+    max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     client: ShardedClient,
@@ -581,6 +582,7 @@ pub async fn run(
         client,
         validation,
         waiting_served_ratio,
+        max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
         max_concurrent_requests,
@@ -19,10 +19,12 @@ class Cache:
     def delete(self, batch_id: int):
         batch = self.pop(batch_id)
         if batch is not None:
+            batch.cleanup()
             del batch

     def clear(self):
-        self.cache.clear()
+        for k in self.cache.keys():
+            self.delete(k)

     def __len__(self):
         return len(self.cache.keys())
@@ -23,7 +23,9 @@ import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional
+from typing import Optional, List, Tuple
+from vllm import attention_ops
+from vllm import cache_ops

 # Flash attention imports
 import flash_attn_cuda
@@ -106,7 +108,7 @@ class FlashLlamaAttention(torch.nn.Module):
             prefix=f"{prefix}.rotary_emb", weights=weights
         )

-        self.softmax_scale = self.head_size ** (-0.5)
+        self.softmax_scale = self.head_size**-0.5

         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load_multi(
@@ -128,14 +130,13 @@ class FlashLlamaAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -144,23 +145,25 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)

-        # Prefill
-        if prefill:
-            # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+        cache_ops.reshape_and_cache(
+            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+        )

-            # output
-            attn_output = torch.empty_like(qkv[:, 0])
+        # output tensor
+        attn_output = torch.empty_like(qkv[:, 0])

+        # Prefill
+        if start_seq_prefill is not None:
             # flash attention
             flash_attn_cuda.fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                start_seq,
-                end_seq,
-                start_seq,
-                end_seq,
+                start_seq_prefill,
+                end_seq_prefill,
+                start_seq_prefill,
+                end_seq_prefill,
                 max_s,
                 max_s,
                 0.0,
@@ -173,31 +176,18 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
-            # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = qkv[:, 1:]
-
-            # output
-            attn_output = torch.empty_like(query)
-            # flash attention
-            flash_attn_cuda.fwd(
-                query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
+            block_size = kv_cache[1].shape[3]
+            attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                start_seq_q,
-                end_seq_q,
-                start_seq,
-                end_seq,
-                1,
-                max_s,
-                0.0,
+                qkv[:, 0],
+                kv_cache[0],
+                kv_cache[1],
                 self.softmax_scale,
-                False,
-                False,
-                False,
-                0,
-                None,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -265,14 +255,13 @@ class FlashLlamaLayer(nn.Module):
         residual,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -281,14 +270,13 @@ class FlashLlamaLayer(nn.Module):
             normed_hidden_states,
             cos,
             sin,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            start_seq_prefill,
+            end_seq_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
             max_s,
-            layer_past,
-            past_present_indices,
-            prefill,
         )

         # faster post attention rms norm
@@ -333,40 +321,18 @@ class FlashLlamaModel(torch.nn.Module):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values=None,
-        pre_allocate_past_size: Optional[int] = None,
-    ):
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+    ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)

-        # Prefill
-        if past_key_values is None:
-            assert pre_allocate_past_size is not None
-
-            prefill = True
-
-            # Create past tensor
-            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(input_ids),
-                    len(self.layers),
-                    2,
-                    self.num_heads,
-                    self.head_size,
-                )
-            )
-        # Decode
-        else:
-            prefill = False
-
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -380,34 +346,18 @@ class FlashLlamaModel(torch.nn.Module):
                 residual,
                 cos,
                 sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                past_key_values[:, i],
-                past_present_indices,
-                prefill,
             )

-        if prefill:
-            present = past_key_values
-            # Create padded past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    pre_allocate_past_size,
-                    len(self.layers),
-                    2,
-                    self.num_heads,
-                    self.head_size,
-                )
-            )
-            # We slice only once instead of at every layer
-            past_key_values[past_present_indices] = present
-
         hidden_states, _ = self.norm(hidden_states, residual)

-        return hidden_states, past_key_values
+        return hidden_states


 class FlashLlamaForCausalLM(torch.nn.Module):
@@ -423,31 +373,29 @@ class FlashLlamaForCausalLM(torch.nn.Module):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values: Optional[torch.Tensor] = None,
-        pre_allocate_past_size: Optional[int] = None,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ):
-        hidden_states, present = self.model(
+    ) -> torch.Tensor:
+        hidden_states = self.model(
             input_ids,
             position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            start_seq_prefill,
+            end_seq_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
             max_s,
-            past_present_indices,
-            past_key_values,
-            pre_allocate_past_size,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
-        return logits, present
+        return logits
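In the decode path above, the flash-attention call over a contiguous layer_past is replaced by vLLM's paged kernels: reshape_and_cache scatters each new key/value vector into the per-block cache tensors at the given slots, and single_query_cached_kv_attention roughly computes attention for one query token per sequence over the blocks listed in that sequence's block table. The plain PyTorch sketch below illustrates that semantics for a single sequence and a single head; it is not the CUDA kernel, the function name is invented, and it ignores the packed x-dimension layout of the key cache shown in CacheManager further down.

import math
import torch

def paged_attention_reference(query, key_cache, value_cache, block_table, context_len, block_size):
    # query: [head_size]; key_cache/value_cache: [num_blocks, head_size, block_size]
    # block_table: 1-D tensor of block indices for this sequence
    head_size = query.shape[0]
    blocks = block_table[: (context_len + block_size - 1) // block_size]
    # Gather this sequence's keys/values from its blocks, then keep the first context_len tokens
    keys = key_cache[blocks].permute(0, 2, 1).reshape(-1, head_size)[:context_len]
    values = value_cache[blocks].permute(0, 2, 1).reshape(-1, head_size)[:context_len]
    scale = 1.0 / math.sqrt(head_size)
    scores = (keys @ query) * scale          # [context_len]
    probs = torch.softmax(scores, dim=0)
    return probs @ values                    # [head_size]

# Tiny smoke test with random data: 6 cached tokens spread over blocks 2 and 5
torch.manual_seed(0)
head_size, block_size, num_blocks = 8, 4, 6
kc = torch.randn(num_blocks, head_size, block_size)
vc = torch.randn(num_blocks, head_size, block_size)
out = paged_attention_reference(
    torch.randn(head_size), kc, vc, torch.tensor([2, 5]), context_len=6, block_size=block_size
)
print(out.shape)  # torch.Size([8])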
@@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         try:
             self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
         except RuntimeError:
-            self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
+            self.shared = TensorParallelEmbedding(
+                prefix="encoder.embed_tokens", weights=weights
+            )

         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
@@ -1,3 +1,4 @@
+import itertools
 import torch
 import torch.distributed

@@ -5,7 +6,7 @@ import numpy as np

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict

 from text_generation_server.models import Model
@@ -20,6 +21,66 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

 tracer = trace.get_tracer(__name__)

+# Will be set in warmup
+CACHE_MANAGER: Optional["CacheManager"] = None
+
+
+class CacheManager:
+    def __init__(
+        self,
+        num_blocks: int,
+        num_layers: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        self.block_size = 16
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        x = self.block_size // element_size
+
+        self.kv_cache = [
+            (
+                torch.empty(
+                    (num_blocks, num_heads, head_size // x, self.block_size, x),
+                    dtype=dtype,
+                    device=device,
+                ),
+                torch.empty(
+                    (num_blocks, num_heads, head_size, self.block_size),
+                    dtype=dtype,
+                    device=device,
+                ),
+            )
+            for _ in range(num_layers)
+        ]
+        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
+        self.slots = torch.arange(
+            0, num_blocks * self.block_size, dtype=torch.int32
+        ).view(num_blocks, self.block_size)
+
+    def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]:
+        # Number of needed block to cover all tokens
+        needed_blocks = (n_tokens // self.block_size) + 1
+
+        # Get free blocks indices by finding values in mask that are not set to 0
+        free_block_indices = self.free_block_mask.nonzero()
+        assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks"
+
+        # Allocate the required number of blocks by setting the mask to 0
+        block_indices = free_block_indices[:needed_blocks]
+        self.free_block_mask[block_indices] = 0
+
+        # Get slots for the allocated blocks
+        slots = self.slots[block_indices].flatten()[:n_tokens]
+
+        return block_indices.flatten().tolist(), slots
+
+    def free(self, block_indices: List[int]):
+        # Reset mask
+        self.free_block_mask[block_indices] = 1
+

 @dataclass
 class FlashCausalLMBatch(Batch):
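The CacheManager added above owns a fixed pool of num_blocks KV blocks of 16 slots each: allocate marks blocks as used in free_block_mask and returns the flat slot indices a request will write into, and free hands blocks back when a request leaves the batch. Below is a small self-contained sketch of the same mask-and-slots bookkeeping, with NumPy standing in for the torch tensors and the class name invented for the example; the block size of 16 and the arithmetic mirror the diff.

import numpy as np

BLOCK_SIZE = 16

class BlockAllocator:
    """Mirrors CacheManager.allocate/free from the diff, minus the KV tensors."""

    def __init__(self, num_blocks: int):
        self.free_block_mask = np.ones(num_blocks, dtype=np.int32)
        # Slot i of block b has flat index b * BLOCK_SIZE + i
        self.slots = np.arange(num_blocks * BLOCK_SIZE, dtype=np.int32).reshape(
            num_blocks, BLOCK_SIZE
        )

    def allocate(self, n_tokens: int):
        needed_blocks = (n_tokens // BLOCK_SIZE) + 1
        free_block_indices = self.free_block_mask.nonzero()[0]
        assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks"
        block_indices = free_block_indices[:needed_blocks]
        self.free_block_mask[block_indices] = 0
        slots = self.slots[block_indices].flatten()[:n_tokens]
        return block_indices.tolist(), slots

    def free(self, block_indices):
        self.free_block_mask[block_indices] = 1

# A request with a 20-token prompt and 11 new tokens needs 20 + 11 - 1 = 30 slots
# (the diff's total_tokens = input_length + max_new_tokens - 1), i.e. two blocks.
allocator = BlockAllocator(num_blocks=8)
blocks, slots = allocator.allocate(30)
print(blocks)      # [0, 1]
print(slots[:4])   # [0 1 2 3]
allocator.free(blocks)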
@@ -32,23 +93,20 @@ class FlashCausalLMBatch(Batch):
     input_ids: torch.Tensor
     position_ids: torch.Tensor

-    # Indices to copy present to the correct indices is the pre-allocated past key values
-    past_present_indices: torch.Tensor
-
-    # tensor of length b holding starting offset of each sequence
-    start_seq: torch.Tensor
-    # tensor of length b holding ending offset of each sequence
-    end_seq: torch.Tensor
     # tensor of length b holding starting offset of each sequence, only used in prefill
     start_seq_prefill: Optional[torch.Tensor]
     # tensor of length b holding ending offset of each sequence, only used in prefill
     end_seq_prefill: Optional[torch.Tensor]
-    # tensor of length b holding starting offset of each query sequence, only used in decode
-    start_seq_q: Optional[torch.Tensor]
-    # tensor of length b holding ending offset of each query sequence, only used in decode
-    end_seq_q: Optional[torch.Tensor]
-    # past key values, only used in decode
-    past_key_values: Optional[torch.Tensor]
+    # list of length b of list of length s_i // block_size
+    block_tables: List[List[int]]
+    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    block_tables_tensor: torch.Tensor
+    # CPU tensor of length b indicating the start of each sequence in slots
+    start_slots: torch.Tensor
+    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
+    slots: torch.Tensor
+    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
+    slot_indices: torch.Tensor
     max_seqlen: int

     # Prefill metadata tensors to efficiently compute logprobs
@@ -62,6 +120,7 @@ class FlashCausalLMBatch(Batch):

     # Lengths of all generations present in the batch
     input_lengths: List[int]
+    input_lengths_tensor: torch.Tensor
     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]

@@ -69,15 +128,16 @@ class FlashCausalLMBatch(Batch):
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]

-    # Maximum number of tokens this batch will grow to
-    max_tokens: int
+    # Maximum number of blocks
+    max_blocks: int

     def to_pb(self) -> generate_pb2.CachedBatch:
+        global CACHE_MANAGER
         return generate_pb2.CachedBatch(
             id=self.batch_id,
             request_ids=[r.id for r in self.requests],
             size=len(self),
-            max_tokens=self.max_tokens,
+            max_tokens=len(self.slots),
         )

     @classmethod
@@ -88,6 +148,8 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        global CACHE_MANAGER
+
         batch_inputs = []
         max_truncation = 0
         for r in pb.requests:
@@ -99,12 +161,12 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]

         position_ids = []
-        past_present_indices = []
-        start_seq = []
-        end_seq = []
         start_seq_prefill = []
         end_seq_prefill = []
-        max_seqlen = 0
+        block_tables = []
+        start_slots = []
+        slots = []
+        slot_indices = []

         input_lengths = []
         prefix_offsets = []
@@ -126,7 +188,9 @@ class FlashCausalLMBatch(Batch):
         cumulative_max_length = 0
         prefill_out_cumulative_length = 0

+        max_seqlen = 0
         max_length = 0
+        max_blocks = 0

         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -138,7 +202,6 @@ class FlashCausalLMBatch(Batch):
                 tokenized_input = tokenized_input[-r.truncate :]

             input_length = len(tokenized_input)
-            max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)

             prefix_offsets.append(input_length - 5)
@@ -153,8 +216,6 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             start_seq_prefill.append(cumulative_length)
             end_seq_prefill.append(cumulative_length + input_length)
-            start_seq.append(cumulative_max_length)
-            end_seq.append(cumulative_max_length + input_length)

             next_token_chooser_parameters.append(r.parameters)

@@ -164,6 +225,21 @@ class FlashCausalLMBatch(Batch):
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

+            # Paged attention
+            # Remove one as the first token des not have a past
+            total_tokens = input_length + max_new_tokens - 1
+            request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens)
+            block_tables.append(request_blocks)
+            slots.extend(request_slots)
+            start_slots.append(cumulative_max_length)
+
+            request_slot_indices = torch.arange(
+                cumulative_max_length,
+                cumulative_max_length + input_length,
+                dtype=torch.int64,
+            )
+            slot_indices.append(request_slot_indices)
+
             all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

@@ -184,22 +260,26 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

-            request_past_present_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
-                dtype=torch.int64,
-            )
-            past_present_indices.append(request_past_present_indices)
-
             # Update
-            # Remove one as the first token des not have a past
             cumulative_length += input_length
-            cumulative_max_length += input_length + max_new_tokens - 1
+            cumulative_max_length += total_tokens
+            max_seqlen = max(max_seqlen, input_length)
+            max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(max_length, input_length + max_new_tokens)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
         )
+        start_slots = torch.tensor(start_slots, dtype=torch.int64)
+
+        # Padded block tables
+        block_tables_tensor = torch.zeros(
+            (len(pb.requests), max_blocks), dtype=torch.int32
+        )
+        for i, request_blocks in enumerate(block_tables):
+            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+
+        block_tables_tensor = block_tables_tensor.to(device)

         # Padded all_input_ids_tensor
         all_input_ids_tensor = np.zeros(
@@ -212,34 +292,29 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = torch.tensor(
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
-        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
-        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)

         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
-            past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
-
-            start_seq_prefill = torch.tensor(
-                start_seq_prefill, device=device, dtype=torch.int32
-            )
-            end_seq_prefill = torch.tensor(
-                end_seq_prefill, device=device, dtype=torch.int32
-            )
+            slot_indices = torch.cat(slot_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
+            slot_indices = slot_indices[0]

-            past_present_indices = past_present_indices[0]
-            start_seq_prefill = start_seq
-            end_seq_prefill = end_seq
+        start_seq_prefill = torch.tensor(
+            start_seq_prefill, device=device, dtype=torch.int32
+        )
+        end_seq_prefill = torch.tensor(
+            end_seq_prefill, device=device, dtype=torch.int32
+        )

+        position_ids = position_ids.to(device)
+        slot_indices = slot_indices.to(device)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        past_present_indices = torch.tensor(
-            past_present_indices, device=device, dtype=torch.int64
+        slots = torch.tensor(slots, dtype=torch.int32, device=device)
+        input_lengths_tensor = torch.tensor(
+            input_lengths, dtype=torch.int32, device=device
         )

         if all_prefill_logprobs:
@@ -262,30 +337,31 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            past_present_indices=past_present_indices,
-            start_seq=start_seq,
-            end_seq=end_seq,
             start_seq_prefill=start_seq_prefill,
             end_seq_prefill=end_seq_prefill,
-            start_seq_q=None,
-            end_seq_q=None,
+            block_tables=block_tables,
+            block_tables_tensor=block_tables_tensor,
+            start_slots=start_slots,
+            slots=slots,
+            slot_indices=slot_indices,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
-            past_key_values=None,
             input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
-            max_tokens=cumulative_max_length,
+            max_blocks=max_blocks,
         )

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        global CACHE_MANAGER
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -294,28 +370,24 @@ class FlashCausalLMBatch(Batch):

         device = self.input_ids.device

-        # Cumulative length
-        cumulative_max_length = 0
-
         # New values after filtering
         requests_idx_mapping = {}

         # Used to index into tensors
         indices = []

-        # past indices to keep
-        past_indices = torch.zeros(
-            self.past_key_values.shape[0], dtype=torch.bool, device=device
+        # slots to keep after filtering
+        slot_filtering_indices = torch.zeros(
+            self.slots.shape[0], dtype=torch.bool, device=device
         )

         # Create on CPU to only move to GPU once instead of at every copy
-        start_seq = torch.empty(len(request_ids), dtype=torch.int32)
-        end_seq = torch.empty(len(request_ids), dtype=torch.int32)
-        start_seq_q = self.start_seq_q[: len(request_ids)]
-        end_seq_q = self.end_seq_q[: len(request_ids)]
+        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
         max_seqlen = 0

         requests = []
+        start_slots = []
+        block_tables = []
         all_input_ids = []

         input_lengths = []
@@ -324,6 +396,10 @@ class FlashCausalLMBatch(Batch):

         stopping_criterias = []

+        max_blocks = 0
+        # Cumulative length
+        cumulative_max_length = 0
+
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
|
|||||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
request_block_table = self.block_tables[idx]
|
||||||
|
block_tables.append(request_block_table)
|
||||||
|
start_slots.append(cumulative_max_length)
|
||||||
|
|
||||||
# Copy to tensor (CPU)
|
# Copy to tensor (CPU)
|
||||||
start_seq[i] = cumulative_max_length
|
slot_indices[i] = cumulative_max_length + request_input_length - 1
|
||||||
end_seq[i] = cumulative_max_length + request_input_length
|
|
||||||
|
|
||||||
# Set slice
|
# Set slice
|
||||||
past_indices[
|
slot_filtering_indices[
|
||||||
self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
|
self.start_slots[idx] : self.start_slots[idx]
|
||||||
|
+ request_input_length
|
||||||
|
+ remaining_tokens
|
||||||
|
- 1
|
||||||
] = True
|
] = True
|
||||||
|
|
||||||
cumulative_max_length += request_input_length + remaining_tokens - 1
|
cumulative_max_length += request_input_length + remaining_tokens - 1
|
||||||
|
|
||||||
|
max_blocks = max(max_blocks, len(request_block_table))
|
||||||
|
|
||||||
|
# Iterate on all requests
|
||||||
|
for i, r in enumerate(self.requests):
|
||||||
|
# Filter requests that are not part of the new batch
|
||||||
|
if r.id not in requests_idx_mapping.keys():
|
||||||
|
# Free blocks
|
||||||
|
CACHE_MANAGER.free(self.block_tables[i])
|
||||||
|
|
||||||
# Index into tensors
|
# Index into tensors
|
||||||
input_ids = self.input_ids[indices]
|
input_ids = self.input_ids[indices]
|
||||||
position_ids = self.position_ids[indices]
|
position_ids = self.position_ids[indices]
|
||||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||||
|
block_tables_tensor = self.block_tables_tensor[indices]
|
||||||
|
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||||
|
slots = self.slots[slot_filtering_indices]
|
||||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||||
past_key_values = self.past_key_values[past_indices]
|
|
||||||
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
# Move to GPU now that we have the whole tensor
|
# Move to GPU now that we have the whole tensor
|
||||||
start_seq = start_seq.to(device)
|
slot_indices = slot_indices.to(device)
|
||||||
end_seq = end_seq.to(device)
|
|
||||||
past_present_indices = end_seq - 1
|
|
||||||
|
|
||||||
return FlashCausalLMBatch(
|
return FlashCausalLMBatch(
|
||||||
batch_id=self.batch_id,
|
batch_id=self.batch_id,
|
||||||
@@ -377,51 +470,74 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            past_present_indices=past_present_indices,
-            start_seq=start_seq,
-            end_seq=end_seq,
             start_seq_prefill=None,
             end_seq_prefill=None,
-            start_seq_q=start_seq_q,
-            end_seq_q=end_seq_q,
+            block_tables=block_tables,
+            block_tables_tensor=block_tables_tensor,
+            start_slots=start_slots,
+            slots=slots,
+            slot_indices=slot_indices,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            past_key_values=past_key_values,
             input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
-            max_tokens=cumulative_max_length,
+            max_blocks=max_blocks,
         )

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+        global CACHE_MANAGER
         # Batch attributes
         requests = []
         requests_idx_mapping = {}

-        total_batch_size = sum([len(b) for b in batches])
-        dtype = batches[0].past_key_values.dtype
-        device = batches[0].input_ids.device
+        total_batch_size = 0
+        total_slots = 0
+        max_blocks = 0
+        max_length = 0
+        max_seqlen = 0
+        for b in batches:
+            total_batch_size += len(b)
+            total_slots += len(b.slots)
+            max_blocks = max(max_blocks, b.max_blocks)
+            max_seqlen = max(max_seqlen, b.max_seqlen)
+            max_length = max(
+                max_length,
+                max(
+                    input_length
+                    + stopping_criteria.max_new_tokens
+                    - stopping_criteria.current_tokens
+                    for input_length, stopping_criteria in zip(
+                        b.input_lengths, b.stopping_criterias
+                    )
+                ),
+            )

         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        start_seq = batches[0].start_seq.new_empty(total_batch_size)
-        end_seq = batches[0].end_seq.new_empty(total_batch_size)
-        start_seq_q = torch.arange(
-            0, total_batch_size, device=device, dtype=torch.int32
+        slots = batches[0].slots.new_empty(total_slots)
+        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
+            total_batch_size
+        )
+        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
+            (total_batch_size, max_blocks)
+        )
+        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
+            (total_batch_size, max_length)
         )
-        end_seq_q = start_seq_q + 1
-        max_seqlen = 0
-        past_key_values = []

+        start_slots = []
+        block_tables = []
         all_input_ids = []

         input_lengths = []
@@ -433,8 +549,7 @@ class FlashCausalLMBatch(Batch):

         # Cumulative length
         cumulative_batch_size = 0
-        max_tokens = 0
-        max_length = 0
+        cumulative_slots = 0

         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
@@ -448,16 +563,27 @@ class FlashCausalLMBatch(Batch):

             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
+            slots_start_index = cumulative_slots
+            slots_end_index = cumulative_slots + len(batch.slots)
+
             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
+            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
+            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            slots[slots_start_index:slots_end_index] = batch.slots

-            start_seq[start_index:end_index] = batch.start_seq + max_tokens
-            end_seq[start_index:end_index] = batch.end_seq + max_tokens
+            all_input_ids_tensor[
+                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
+            ] = batch.all_input_ids_tensor[:, :max_length]

-            max_seqlen = max(max_seqlen, batch.max_seqlen)
+            block_tables_tensor[
+                start_index:end_index, : batch.block_tables_tensor.shape[1]
+            ] = batch.block_tables_tensor[:, :max_blocks]
+
+            start_slots.append(batch.start_slots + cumulative_slots)
+
+            block_tables.extend(batch.block_tables)
             all_input_ids.extend(batch.all_input_ids)

             input_lengths.extend(batch.input_lengths)
@@ -466,43 +592,17 @@ class FlashCausalLMBatch(Batch):

             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
-            past_key_values.append(batch.past_key_values)

             # Update
             cumulative_batch_size += len(batch)
-            max_tokens += batch.max_tokens
-            max_length = max(
-                max_length,
-                max(
-                    input_length
-                    + stopping_criteria.max_new_tokens
-                    - stopping_criteria.current_tokens
-                    for input_length, stopping_criteria in zip(
-                        batch.input_lengths, batch.stopping_criterias
-                    )
-                ),
-            )
-
-        past_key_values = torch.cat(past_key_values, dim=0)
-        past_present_indices = end_seq - 1
-
-        all_input_ids_tensor = torch.zeros(
-            (total_batch_size, max_length), dtype=torch.int64, device=device
-        )
-
-        cumulative_batch_size = 0
-        for i, batch in enumerate(batches):
-            start_index = cumulative_batch_size
-            end_index = cumulative_batch_size + len(batch)
-
-            all_input_ids_tensor[
-                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-            ] = batch.all_input_ids_tensor[:, :max_length]
-
-            cumulative_batch_size += len(batch)
+            cumulative_slots += len(batch.slots)
+
+        start_slots = torch.concat(start_slots)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype=dtype, device=device
+            next_token_chooser_parameters,
+            dtype=batches[0].next_token_chooser.dtype,
+            device=batches[0].next_token_chooser.device,
         )

         return FlashCausalLMBatch(
@@ -511,28 +611,33 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            past_present_indices=past_present_indices,
-            start_seq=start_seq,
-            end_seq=end_seq,
             start_seq_prefill=None,
             end_seq_prefill=None,
-            start_seq_q=start_seq_q,
-            end_seq_q=end_seq_q,
+            block_tables=block_tables,
+            block_tables_tensor=block_tables_tensor,
+            start_slots=start_slots,
+            slots=slots,
+            slot_indices=slot_indices,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            past_key_values=past_key_values,
             input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
-            max_tokens=max_tokens,
+            max_blocks=max_blocks,
         )

+    def cleanup(self):
+        global CACHE_MANAGER
+        # Free blocks
+        CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
+
     def __len__(self):
         return len(self.requests)

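Note: cleanup hands every block id referenced by the batch back to the global cache manager in a single call; itertools.chain.from_iterable is just the flattening of the per-request block tables. For example:

import itertools

block_tables = [[0, 1, 2], [7], [3, 4]]  # one block-id list per request
flat = list(itertools.chain.from_iterable(block_tables))
assert flat == [0, 1, 2, 7, 3, 4]
# CACHE_MANAGER.free(flat) would then return all of these blocks to the pool.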
@@ -540,32 +645,24 @@ class FlashCausalLMBatch(Batch):
 class FlashCausalLM(Model):
     def __init__(
         self,
-        model_cls: Type[PreTrainedModel],
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        trust_remote_code: bool = False,
+        model: torch.nn.Module,
+        tokenizer: PreTrainedTokenizerBase,
+        num_layers: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        rank: int = 0,
+        world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.float16
-        else:
-            raise NotImplementedError("FlashCausalLM is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        model = model_cls.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-        ).to(device)
+        self.num_heads = num_heads
+        self.head_size = head_size
+
+        global CACHE_MANAGER
+        torch.cuda.set_per_process_memory_fraction(1.0)
+        CACHE_MANAGER = CacheManager(
+            1000, num_layers, num_heads, head_size, dtype, device
+        )

         super(FlashCausalLM, self).__init__(
             model=model,
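Note: the CacheManager itself is not part of this diff, so the sketch below is only an assumption of what a block/slot allocator with this constructor shape might look like; the block size of 16, the cache tensor layout, and the names SketchCacheManager/free_blocks are all illustrative, not the project's actual implementation.

import torch

BLOCK_SIZE = 16  # assumed number of token slots per block


class SketchCacheManager:
    """Illustrative only: a pre-allocated, block-granular KV cache."""

    def __init__(self, num_blocks, num_layers, num_heads, head_size, dtype, device):
        # One (key, value) pair of caches per layer, carved into fixed-size blocks
        # (the exact tensor layout is an assumption).
        self.kv_cache = [
            (
                torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size, dtype=dtype, device=device),
                torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size, dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]
        self.free_blocks = list(range(num_blocks))

    def allocate(self, n_blocks):
        # Hand out block ids; a real manager would also build the per-token slot mapping.
        blocks, self.free_blocks = self.free_blocks[:n_blocks], self.free_blocks[n_blocks:]
        return blocks

    def free(self, block_ids):
        # Return block ids to the pool so future batches can reuse them.
        self.free_blocks.extend(block_ids)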
@@ -573,6 +670,8 @@ class FlashCausalLM(Model):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @property
@@ -588,28 +687,27 @@ class FlashCausalLM(Model):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        start_seq: torch.Tensor,
-        end_seq: torch.Tensor,
-        start_seq_q: Optional[torch.Tensor],
-        end_seq_q: Optional[torch.Tensor],
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
         max_s: int,
-        past_present_indices: torch.Tensor,
-        past_key_values: Optional = None,
-        pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        global CACHE_MANAGER

         # Model Forward
         return self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            start_seq=start_seq,
-            end_seq=end_seq,
-            start_seq_q=start_seq_q,
-            end_seq_q=end_seq_q,
+            start_seq_prefill=start_seq_prefill,
+            end_seq_prefill=end_seq_prefill,
+            kv_cache=CACHE_MANAGER.kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
             max_s=max_s,
-            past_present_indices=past_present_indices,
-            past_key_values=past_key_values,
-            pre_allocate_past_size=pre_allocate_past_size,
             lm_head_indices=lm_head_indices,
         )

@@ -617,31 +715,18 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        prefill = batch.past_key_values is None
+        prefill = batch.start_seq_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

-        if prefill:
-            # Ask to pre-allocate kv to its max size
-            # == Sum over batch size (number of tokens + max_new_tokens) - batch size
-            pre_allocate_past_size = batch.max_tokens
-            start_seq = batch.start_seq_prefill
-            end_seq = batch.end_seq_prefill
-        else:
-            pre_allocate_past_size = None
-            start_seq = batch.start_seq
-            end_seq = batch.end_seq
-
-        out, present = self.forward(
+        out = self.forward(
             batch.input_ids,
             batch.position_ids,
-            start_seq,
-            end_seq,
-            batch.start_seq_q,
-            batch.end_seq_q,
+            batch.start_seq_prefill,
+            batch.end_seq_prefill,
+            batch.block_tables_tensor,
+            batch.slots[batch.slot_indices],
+            batch.input_lengths_tensor,
             batch.max_seqlen,
-            batch.past_present_indices,
-            batch.past_key_values,
-            pre_allocate_past_size,
             batch.prefill_head_indices,
         )

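Note: the call above selects the cache slots for the current step with plain tensor indexing, batch.slots[batch.slot_indices]; during decode that is one slot per sequence, while during prefill the indices typically cover every prompt token. Roughly, with toy values (not the real layout):

import torch

slots = torch.tensor([8, 9, 10, 32, 33])   # every slot owned by the batch
slot_indices = torch.tensor([2, 4])        # current write position per sequence
assert torch.equal(slots[slot_indices], torch.tensor([10, 33]))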
@@ -662,12 +747,8 @@ class FlashCausalLM(Model):
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            # Create batch.start_seq_q and batch.end_seq_q for decode
-            batch.start_seq_q = torch.arange(
-                0, len(batch), device=self.device, dtype=torch.int32
-            )
-            batch.end_seq_q = batch.start_seq_q + 1
             next_position_ids = batch.position_ids.new_empty(len(batch))
+            batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
             # We do not need start_seq_prefill and end_seq_prefill anymore
             batch.start_seq_prefill = None
             batch.end_seq_prefill = None
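Note: after prefill only the last position of each sequence matters for the next decode step. Assuming end_seq_prefill holds the cumulative end offset of each prompt in the flattened prefill batch, indexing at end_seq_prefill - 1 keeps exactly one slot index per sequence (toy values below, for illustration only):

import torch

# Two prompts of length 3 and 2, flattened for prefill: five slot indices total.
slot_indices = torch.tensor([0, 1, 2, 16, 17])
end_seq_prefill = torch.tensor([3, 5])     # cumulative end offset of each prompt
decode_indices = slot_indices[end_seq_prefill - 1]
assert torch.equal(decode_indices, torch.tensor([2, 17]))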
@@ -731,8 +812,8 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
-        batch.past_present_indices = batch.end_seq
-        batch.end_seq = batch.end_seq + 1
+        batch.input_lengths_tensor += 1
+        batch.slot_indices += 1

         if prefill and prefill_logprobs:
             # Get prefill logprobs
@@ -755,7 +836,6 @@ class FlashCausalLM(Model):
             batch.read_offsets,
             batch.stopping_criterias,
             batch.all_input_ids,
-            batch.all_input_ids_tensor,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             next_token_ids,
@@ -770,7 +850,6 @@ class FlashCausalLM(Model):
             read_offset,
             stopping_criteria,
             all_input_ids,
-            all_input_ids_tensor,
             do_sample,
             seed,
             next_token_id,
@@ -845,19 +924,20 @@ class FlashCausalLM(Model):

             generations.append(generation)

-            new_input_length = input_length + 1
-
             # Update values
-            batch.input_lengths[i] = new_input_length
+            batch.input_lengths[i] = input_length + 1
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids

+        if stopped:
+            batch.cleanup()
+            # No need to return a batch if we know that all requests stopped
+            return generations, None
+
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
         batch.max_seqlen = batch.max_seqlen + 1
-        batch.past_key_values = present

-        # No need to return a batch if we know that all requests stopped
-        return generations, batch if not stopped else None
+        return generations, batch
@@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM):
         model = FlashLlamaForCausalLM(config, weights)

         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashLlama, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.model.layers),
+            num_heads=model.model.num_heads,
+            head_size=model.model.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
@@ -52,8 +52,11 @@ class FlashSantacoderSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group,
-            aliases = {"transformer.wte.weight": ["lm_head.weight"]}
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )

         model = FlashSantacoderForCausalLM(config, weights)
@@ -35,6 +35,9 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError

+    def cleanup(self):
+        pass
+
     @abstractmethod
     def __len__(self):
         raise NotImplementedError
@@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser:

         self.seeds = seeds
         self.do_sample = do_sample
+        self.dtype = dtype
+        self.device = device

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
         if self.watermark_processor is not None:
@@ -5,7 +5,14 @@ import torch


 class Weights:
-    def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
+    def __init__(
+        self,
+        filenames: List[Path],
+        device,
+        dtype,
+        process_group,
+        aliases: Optional[Dict[str, List[str]]] = None,
+    ):
         routing = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
@@ -43,7 +50,7 @@ class Weights:
         return str(filename), tensor_name

     def _get_slice(self, tensor_name: str):
-        filename, tensor_name= self.get_filename(tensor_name)
+        filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
         return slice_
@@ -94,12 +101,20 @@ class Weights:
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )

-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
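Note: for column-parallel GPTQ layers the packed qweight/qzeros/scales of each prefix are concatenated along the output dimension (dim=1), while g_idx indexes the shared input dimension and must therefore be identical across prefixes; the assert_close above enforces that. A toy illustration (values are made up):

import torch

q_shard = torch.zeros(4, 2, dtype=torch.int32)  # packed columns for one prefix
k_shard = torch.ones(4, 2, dtype=torch.int32)
qweight = torch.cat([q_shard, k_shard], dim=1)  # (4, 4): output columns side by side

g_idx_q = torch.tensor([0, 0, 1, 1])
g_idx_k = torch.tensor([0, 0, 1, 1])
torch.testing.assert_close(g_idx_q, g_idx_k)    # shared input dim: must be identical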
@@ -118,7 +133,9 @@ class Weights:
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
             qzeros = self.get_tensor(f"{prefix}.qzeros")
             scales = self.get_tensor(f"{prefix}.scales")
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)