Mirror of https://github.com/huggingface/text-generation-inference.git
feedback loop
commit c8a033b636
parent ff4155dfea
@@ -218,8 +218,13 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
@@ -237,11 +242,7 @@ impl Client {
         &mut self,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest {
-            batch: None,
-            batches,
-        })
-        .inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
@@ -134,11 +134,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -256,7 +257,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }
@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -36,18 +36,14 @@ impl BackendV2 {
         speculate: u32,
     ) -> Self {
         // Infer shared state
-        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention
-                .parse()
-                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
-        } else {
-            Attention::Paged
-        };
-        let block_size = if attention == Attention::FlashDecoding {
-            256
-        } else {
-            16
+        let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
+        let block_size = match attention.as_str() {
+            "flashinfer" => 1,
+            "flashdecoding" => 256,
+            "paged" => 16,
+            _ => unreachable!(),
         };

         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
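For reference, here is a standalone sketch (not part of this commit) of the attention-kind to block-size mapping that the hunk above hard-codes in BackendV2::new. The `block_size_for` helper name is hypothetical; the values simply mirror the match arms (flashinfer -> 1, flashdecoding -> 256, paged -> 16).

// Illustrative sketch only: same values as the v2 backend hunk above.
fn block_size_for(attention: &str) -> u32 {
    match attention {
        "flashinfer" => 1,
        "flashdecoding" => 256,
        "paged" => 16,
        other => panic!("Invalid attention was specified: `{other}`"),
    }
}

fn main() {
    // Mirrors `std::env::var("ATTENTION").unwrap_or("paged".to_string())`.
    let attention = std::env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
    println!("block_size = {}", block_size_for(&attention));
}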
@@ -1,12 +1,14 @@
-use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
 /// Batching and inference logic
+use crate::client::{
+    Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
+};
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -31,32 +33,22 @@ impl BackendV3 {
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
-        support_chunking: bool,
+        shard_info: InfoResponse,
     ) -> Self {
-        if support_chunking {
+        if shard_info.support_chunking {
             tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
         }

-        let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
-        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
-
-        let attention: Attention = attention
-            .parse()
-            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
-        let block_size = attention.block_size();
-
+        let block_size = shard_info.block_size;
         let queue = Queue::new(
-            requires_padding,
+            shard_info.requires_padding,
             block_size,
-            prefix_caching,
-            window_size,
-            speculate,
+            shard_info.use_prefix_caching,
+            shard_info.window_size,
+            shard_info.speculate,
             max_batch_total_tokens,
-            support_chunking,
+            shard_info.support_chunking,
         );
         let batching_task_notifier = Arc::new(Notify::new());
@@ -68,7 +60,7 @@ impl BackendV3 {
             max_batch_total_tokens,
             max_waiting_tokens,
             max_batch_size,
-            support_chunking,
+            shard_info.support_chunking,
             queue.clone(),
             batching_task_notifier.clone(),
         ));
@@ -154,7 +146,7 @@ pub(crate) async fn batching_task(
             )
             .await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -175,7 +167,8 @@ pub(crate) async fn batching_task(
                 let (min_size, max_size, prefill_token_budget) = if support_chunking {
                     // Since the next batch will be concatenated with the current batch,
                     // the current batch tokens must be subtracted to the prefill budget
-                    let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
+                    let prefill_token_budget =
+                        max_batch_prefill_tokens.saturating_sub(current_tokens);
                     // We can ignore min_size and max_size
                     // Models than rely on max_size cannot support chunking
                     // Regarding min_size, chunking allow us to consistently run at the compute
@@ -199,10 +192,8 @@ pub(crate) async fn batching_task(
                     (min_size, max_size, max_batch_prefill_tokens)
                 };

-                let mut additional_batch = None;
-
                 // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) = queue
+                if let Some((new_entries, new_batch, span)) = queue
                     .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                     .await
                 {
@@ -218,11 +209,11 @@ pub(crate) async fn batching_task(
                         };
                         counter.increment(1);
                     }
-                    if support_chunking {
-                        entries.extend(new_entries);
-                        additional_batch = Some(new_batch);
+                    let cached_batch = if support_chunking {
+                        // Concat current batch to the new one
+                        batches.pop()
                     } else {
+                        // Request are waiting only if we don't support chunking
                         entries.iter_mut().for_each(|(_, entry)| {
                             // Create a new span to add the info that this entry is waiting
                             // because a new batch is being computed
@@ -233,18 +224,23 @@ pub(crate) async fn batching_task(
                             // Update entry
                             entry.temp_span = Some(entry_waiting_span);
                         });
+                        None
+                    };
+                    entries.extend(new_entries);

                    // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, cached_batch, &mut entries)
                            .instrument(span)
                            .await;
                    // Reset waiting counter
                    waiting_tokens = 1;
                    // Extend current batch with the new batch
                    if let Some(new_cached_batch) = new_cached_batch {
-                        entries.extend(new_entries);
                        batches.push(new_cached_batch);
-                    }
+                    } else if support_chunking {
+                        // New cached batch is empty, no work left
+                        break;
                    }
                }
@@ -262,7 +258,7 @@ pub(crate) async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });

-                cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
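The hunks above are the heart of this commit's feedback loop: `prefill` now receives the previous `CachedBatch`, and when the model supports chunking the batching task pops the current cached batch and feeds it back into the next prefill instead of handing an extra batch to `decode`. Below is a deliberately simplified, synchronous sketch of that control flow; every type and function in it is a stub invented for illustration, not TGI's actual API.

// Illustrative sketch only: a synchronous stand-in for the batching loop.
#[derive(Clone, Debug)]
struct Batch {
    id: u64,
}

#[derive(Clone, Debug)]
struct CachedBatch {
    id: u64,
    remaining_tokens: u32,
}

// Stand-in for `Client::prefill(batch, cached_batch)`: the server would
// concatenate the cached batch (if any) with the new chunk and return the
// merged cached batch.
fn prefill(batch: Batch, cached_batch: Option<CachedBatch>) -> Option<CachedBatch> {
    let carried = cached_batch.map(|b| b.remaining_tokens).unwrap_or(0);
    Some(CachedBatch {
        id: batch.id,
        remaining_tokens: carried + 4,
    })
}

// Stand-in for `Client::decode(batches)`: decode one token for every cached
// batch and keep at most one unfinished batch.
fn decode(batches: Vec<CachedBatch>) -> Option<CachedBatch> {
    let mut survivors: Vec<CachedBatch> = batches
        .into_iter()
        .filter_map(|mut b| {
            b.remaining_tokens = b.remaining_tokens.saturating_sub(1);
            (b.remaining_tokens > 0).then_some(b)
        })
        .collect();
    survivors.pop()
}

fn main() {
    let support_chunking = true;
    // The first prefill has no cached batch, mirroring `prefill(..., None, ...)`.
    let mut cached_batch = prefill(Batch { id: 0 }, None);

    for new_id in 1u64..=3 {
        let mut batches: Vec<CachedBatch> = cached_batch.take().into_iter().collect();
        // With chunking, pop the current cached batch and feed it back into the
        // next prefill; without chunking, prefill the new batch on its own.
        let feedback = if support_chunking { batches.pop() } else { None };
        let new_cached_batch = prefill(Batch { id: new_id }, feedback);
        batches.extend(new_cached_batch);
        cached_batch = decode(batches);
        println!("after step {new_id}: {cached_batch:?}");
    }
}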
@@ -277,13 +273,14 @@ pub(crate) async fn batching_task(
 async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
+    cached_batch: Option<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

-    match client.prefill(batch).await {
+    match client.prefill(batch, cached_batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
@@ -292,6 +289,10 @@ async fn prefill(
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                    .record(concat_duration.as_secs_f64());
+            }
             metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
                 .record(timings.forward.as_secs_f64());
             metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
@@ -316,7 +317,6 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batch: Option<Batch>,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
@@ -324,7 +324,7 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

-    match client.decode(batch, batches).await {
+    match client.decode(batches).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
@@ -218,13 +218,23 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
             response.batch,
-            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+            PrefillTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
         ))
     }

@@ -235,10 +245,9 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
@@ -254,14 +263,16 @@ impl Client {
 }

 pub struct PrefillTimings {
+    pub concat: Option<Duration>,
     pub forward: Duration,
     pub decode: Duration,
     pub total: Duration,
 }

 impl PrefillTimings {
-    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
         Self {
+            concat: concat_ns.map(Duration::from_nanos),
             forward: Duration::from_nanos(forward_ns),
             decode: Duration::from_nanos(decode_ns),
             total: Duration::from_nanos(total_ns),
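A minimal sketch of how the optional concatenation timing introduced above can flow into the timings struct. The struct mirrors the diff (and the new `optional uint64 concat_ns` proto field further down); the `main` below is only an illustration, not the client module.

use std::time::Duration;

#[derive(Debug)]
struct PrefillTimings {
    concat: Option<Duration>,
    forward: Duration,
    decode: Duration,
    total: Duration,
}

impl PrefillTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            // `None` means chunking was off or no cached batch was merged.
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

fn main() {
    for concat_ns in [None, Some(1_200_000u64)] {
        let timings = PrefillTimings::new(concat_ns, 8_000_000, 3_000_000, 12_000_000);
        // Only report the concat duration when it is present.
        if let Some(concat) = timings.concat {
            println!("concat took {concat:?}");
        }
        println!("{timings:?}");
    }
}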
@@ -135,11 +135,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -167,13 +168,12 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
     pub async fn decode(
         &mut self,
-        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
+            .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
@@ -246,7 +246,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }
@@ -31,6 +31,12 @@ pub struct BackendInfo {
     pub max_batch_size: Option<usize>,
     #[schema(example = "false")]
     pub support_chunking: bool,
+    #[schema(example = "false")]
+    pub prefix_caching: bool,
+    #[schema(example = "flashinfer")]
+    pub attention_impl: String,
+    #[schema(example = "1")]
+    pub block_size: u32,
 }

 #[allow(clippy::too_many_arguments)]
@@ -113,6 +119,9 @@ pub async fn connect_backend(
         model_dtype: shard_info.dtype.clone(),
         speculate: shard_info.speculate as usize,
         support_chunking: shard_info.support_chunking,
+        prefix_caching: shard_info.use_prefix_caching,
+        attention_impl: shard_info.attention_impl.clone(),
+        block_size: shard_info.block_size,
     };

     let backend = BackendV3::new(
@@ -122,10 +131,7 @@ pub async fn connect_backend(
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
-        shard_info.requires_padding,
-        shard_info.window_size,
-        shard_info.speculate,
-        shard_info.support_chunking,
+        shard_info,
     );

     tracing::info!("Using backend V3");
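The two hunks above replace four unpacked scalars with the whole shard `InfoResponse`. Here is a trimmed-down sketch of that shape, using stand-in types rather than the generated prost structs, to show why new shard capabilities now only need a proto field plus a read in the backend.

// Sketch only: stand-in types, not the real InfoResponse/BackendV3.
struct InfoResponse {
    requires_padding: bool,
    window_size: Option<u32>,
    speculate: u32,
    support_chunking: bool,
    use_prefix_caching: bool,
    block_size: u32,
}

struct BackendV3 {
    shard_info: InfoResponse,
    max_batch_total_tokens: u32,
}

impl BackendV3 {
    // One struct argument replaces requires_padding/window_size/speculate/support_chunking.
    fn new(max_batch_total_tokens: u32, shard_info: InfoResponse) -> Self {
        if shard_info.support_chunking {
            eprintln!("Model supports prefill chunking.");
        }
        Self {
            shard_info,
            max_batch_total_tokens,
        }
    }
}

fn main() {
    let shard_info = InfoResponse {
        requires_padding: false,
        window_size: None,
        speculate: 0,
        support_chunking: true,
        use_prefix_caching: true,
        block_size: 1,
    };
    let backend = BackendV3::new(1024, shard_info);
    println!(
        "padding={} window={:?} speculate={} budget={}",
        backend.shard_info.requires_padding,
        backend.shard_info.window_size,
        backend.shard_info.speculate,
        backend.max_batch_total_tokens
    );
    println!(
        "chunking={} prefix_caching={} block_size={}",
        backend.shard_info.support_chunking,
        backend.shard_info.use_prefix_caching,
        backend.shard_info.block_size
    );
}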
@@ -89,6 +89,10 @@ impl Queue {
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
+        if prefill_token_budget == 0 || token_budget == 0 {
+            return None;
+        };
+
         // Create response channel
         let (response_sender, response_receiver) = oneshot::channel();
         // Send next batch command to the background task managing the state
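A tiny self-contained sketch of the early-return guard added to `Queue::next_batch` above: with a zero prefill or token budget there is nothing to schedule, so the caller gets `None` without a round-trip to the background task. `NextBatch` is reduced to a stub type for the example.

type NextBatch = Vec<u64>;

fn next_batch(prefill_token_budget: u32, token_budget: u32) -> Option<NextBatch> {
    if prefill_token_budget == 0 || token_budget == 0 {
        return None;
    }
    // The real implementation asks the background task for queued entries here.
    Some(vec![42])
}

fn main() {
    assert!(next_batch(0, 512).is_none());
    assert!(next_batch(512, 0).is_none());
    assert_eq!(next_batch(512, 512), Some(vec![42]));
}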
@@ -174,7 +174,7 @@ async fn prefill(

     // Run prefill
     let start_time = Instant::now();
-    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;

     // Get latency
     let latency = start_time.elapsed();
@@ -35,6 +35,9 @@ message InfoResponse {
     optional uint32 window_size = 4;
     uint32 speculate = 5;
     bool support_chunking = 6;
+    bool use_prefix_caching = 7;
+    string attention_impl = 8;
+    uint32 block_size = 9;
 }

 /// Empty request
@@ -225,6 +228,8 @@ message FilterBatchResponse {
 message PrefillRequest {
     /// Batch
     Batch batch = 1;
+    /// Optional cached batch
+    CachedBatch cached_batch = 2;
 }

 message PrefillResponse {
@@ -238,13 +243,13 @@ message PrefillResponse {
     uint64 decode_ns = 4;
     /// Total elapsed time in nanoseconds
     uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
 }

 message DecodeRequest {
     /// Cached batches
     repeated CachedBatch batches = 1;
-    /// Optional Batch
-    optional Batch batch = 2;
 }

 message DecodeResponse {
@@ -18,45 +18,6 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;

-#[derive(PartialEq)]
-pub enum Attention {
-    Paged,
-    FlashDecoding,
-    FlashInfer,
-}
-
-impl Attention {
-    pub fn block_size(&self) -> u32 {
-        match self {
-            Attention::FlashDecoding => 256,
-            Attention::FlashInfer => 1,
-            Attention::Paged => 16,
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct ParseError;
-
-impl std::fmt::Display for ParseError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Cannot parse attention value")
-    }
-}
-impl std::error::Error for ParseError {}
-
-impl std::str::FromStr for Attention {
-    type Err = ParseError;
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "paged" => Ok(Attention::Paged),
-            "flashdecoding" => Ok(Attention::FlashDecoding),
-            "flashinfer" => Ok(Attention::FlashInfer),
-            _ => Err(ParseError),
-        }
-    }
-}
-
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -76,7 +76,7 @@ class CausalLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self),
+            current_tokens=len(self.input_ids),
         )

     @classmethod
@@ -171,7 +171,7 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward
     prefill_cu_outlens: Optional[List[int]]
     # Will be set by `generate_token` and reset after each prefill forward
-    prefill_tokens: List[Optional[Tokens]]
+    prefill_logprob_tokens: List[Optional[Tokens]]

     # Prefixes
     prefix_ids: List[List[int]]
@@ -290,8 +290,7 @@ class FlashCausalLMBatch(Batch):
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
-                assert prefix_length > 0
-                prefix_length -= 1
+                assert False, "unreachable"
             if prefix_length + postfix_length < prompt_length:
                 # FIXME: speculate is not supported for context chunking at the moment
                 assert speculate == 0
@@ -303,7 +302,9 @@ class FlashCausalLMBatch(Batch):
                 prefix_length : prefix_length + postfix_length
             ]

-            postfix_length = len(postfix_ids)
+            assert (
+                len(postfix_ids) == postfix_length
+            ), "Rust and Python tokenizers are not aligned"
             postfix_lengths.append(postfix_length)

             prefix_offsets.append(prompt_length - 5)
@@ -394,7 +395,7 @@ class FlashCausalLMBatch(Batch):
             max_current_length=max_current_length,
             prefilling=True,
             prefilling_mask=[True] * len(pb.requests),
-            prefill_tokens=[None] * len(pb.requests),
+            prefill_logprob_tokens=[None] * len(pb.requests),
             postfix_lengths=postfix_lengths,
             prompt_lengths=prompt_lengths,
             prefix_offsets=prefix_offsets,
@@ -475,7 +476,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []

         prefilling_mask = []
-        prefill_tokens = []
+        prefill_logprob_tokens = []

         stopping_criterias = []
         top_n_tokens = []
@@ -518,7 +519,7 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)

             top_n_tokens.append(self.top_n_tokens[idx])
-            prefill_tokens.append(self.prefill_tokens[idx])
+            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
@@ -611,7 +612,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            prefill_tokens=prefill_tokens,
+            prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
@@ -726,7 +727,7 @@ class FlashCausalLMBatch(Batch):
         prefix_offsets = []
         read_offsets = []

-        prefill_tokens = []
+        prefill_logprob_tokens = []

         next_token_chooser_parameters = []
         fsm_grammar_states = []
@@ -814,7 +815,7 @@ class FlashCausalLMBatch(Batch):
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)

-            prefill_tokens.extend(batch.prefill_tokens)
+            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)

             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
@@ -869,7 +870,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            prefill_tokens=prefill_tokens,
+            prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
@@ -1769,9 +1770,10 @@ class FlashCausalLM(Model):
         if get_support_chunking():
             next_prefilling_mask = []
             # Budget in tokens for the next batch
-            # We remove len(batch) to always have enough space for at least a single decode
-            # for the remaining requests
-            batch_budget = get_max_prefill_tokens() - len(batch)
+            # We remove (len(batch) - 1) to always have enough space for at least a single decode
+            # for the remaining requests -1 because the first request does not need to be removed from the budget
+            # (ex: you have one request in the batch, you want it to take the full budget not budget -1)
+            batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
             # We reverse to prioritize older requests
             # zip() is not reversible so reverse the underlying lists instead
             for prefix_length, postfix_length, prompt_length in zip(
@@ -1790,6 +1792,7 @@ class FlashCausalLM(Model):
                     finished_prefilling = False
                     next_prefilling_mask.append(True)
                 else:
+                    # FIXME: use true number of accepted tokens instead of 1
                     # Since speculation will be turned off, this is always true
                     next_chunk_length = 1
                     next_prefilling_mask.append(False)
@@ -1807,14 +1810,7 @@ class FlashCausalLM(Model):
         batch.prefilling = not finished_prefilling
         batch.prefilling_mask = next_prefilling_mask

-        # Turn off speculative if some requests are still prefilling
-        # It makes the logic easier to follow
-        if prefill and not finished_prefilling:
-            speculate = 0
-            speculative_logits = None
-        else:
-            speculate = get_speculate()
+        speculate = get_speculate()

         (
             next_input_ids,
             next_token_logprobs,
@@ -2045,18 +2041,18 @@ class FlashCausalLM(Model):
             # this state to be stable
             if request.id % self.world_size == self.rank:
                 # Prefill
-                if prefill and request.prefill_logprobs:
+                if request_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]

-                    request_prefill_tokens = batch.prefill_tokens[i]
-
                     request_prefill_logprobs = prefill_logprobs[
                         out_start_index : out_end_index - 1
                     ]
                     prefill_token_ids = all_input_ids[:-1]

-                    if request_prefill_tokens is None:
+                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
+
+                    if past_prefill_logprob_tokens is None:
                         # Remove generated token to only have prefill and add nan for first prompt token
                         request_prefill_logprobs = [float("nan")] * (
                             len(prefix_ids) + 1
@@ -2069,18 +2065,20 @@ class FlashCausalLM(Model):
                             skip_special_tokens=False,
                         )

-                    prefill_tokens = Tokens(
+                    prefill_logprob_tokens = Tokens(
                         prefill_token_ids,
                         request_prefill_logprobs,
                         prefill_texts,
                         is_special=[],
                     )
-                    if request_prefill_tokens is not None:
-                        prefill_tokens = request_prefill_tokens + prefill_tokens
+                    if past_prefill_logprob_tokens is not None:
+                        prefill_logprob_tokens = (
+                            past_prefill_logprob_tokens + prefill_logprob_tokens
+                        )

-                    batch.prefill_tokens[i] = prefill_tokens
+                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                 else:
-                    batch.prefill_tokens[i] = None
+                    batch.prefill_logprob_tokens[i] = None

                 # If it is, the tokens we decoded should be ignored
                 if request_prefilling:
@@ -2178,7 +2176,7 @@ class FlashCausalLM(Model):

                 generation = Generation(
                     request.id,
-                    batch.prefill_tokens[i],
+                    batch.prefill_logprob_tokens[i],
                     Tokens(
                         _next_token_ids,
                         _next_token_logprobs,
@@ -7,6 +7,7 @@ from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
 from loguru import logger

+from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE
 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.log import log_master
 from text_generation_server.utils.prefill_chunking import set_support_chunking
@@ -94,6 +95,9 @@ class Model(ABC):
             window_size=self.sliding_window,
             speculate=self.speculate,
             support_chunking=self.support_chunking,
+            use_prefix_caching=PREFIX_CACHING,
+            attention_impl=ATTENTION,
+            block_size=BLOCK_SIZE,
         )

     @property
@@ -80,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self),
+            current_tokens=len(self.input_ids),
         )

     @classmethod
@@ -153,6 +153,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

+        concat_ns = None
+        if self.model.support_chunking:
+            if request.HasField("cached_batch"):
+                cached_batch = self.cache.pop(request.cached_batch.id)
+                if cached_batch is None:
+                    raise ValueError(
+                        f"Batch ID {request.cached_batch.id} not found in cache."
+                    )
+                start_concat = time.time_ns()
+                batch = self.model.batch_type.concatenate([batch, cached_batch])
+                concat_ns = time.time_ns() - start_concat
+
         generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

@@ -162,6 +174,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             forward_ns=timings[0],
             decode_ns=timings[1],
             total_ns=time.time_ns() - start,
+            concat_ns=concat_ns,
         )

     async def Decode(self, request, context):
@@ -179,16 +192,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")

-        if self.model.support_chunking:
-            if request.HasField("batch"):
-                batch = self.model.batch_type.from_pb(
-                    request.batch,
-                    self.model.tokenizer,
-                    self.model.dtype,
-                    self.model.device,
-                )
-                batches.append(batch)
-
         if len(batches) > 1:
             start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)