Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
feedback loop
Parent: ff4155dfea · Commit: c8a033b636
@@ -218,8 +218,13 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,

@@ -237,11 +242,7 @@ impl Client {
         &mut self,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest {
-            batch: None,
-            batches,
-        })
-        .inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,

@@ -134,11 +134,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =

@@ -256,7 +257,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }

@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;

@@ -36,18 +36,14 @@ impl BackendV2 {
         speculate: u32,
     ) -> Self {
         // Infer shared state
-        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention
-                .parse()
-                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
-        } else {
-            Attention::Paged
-        };
-        let block_size = if attention == Attention::FlashDecoding {
-            256
-        } else {
-            16
+        let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
+        let block_size = match attention.as_str() {
+            "flashinfer" => 1,
+            "flashdecoding" => 256,
+            "paged" => 16,
+            _ => unreachable!(),
         };

         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());

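For reference, a minimal standalone sketch (not part of this commit) of the ATTENTION-to-block-size mapping used in the hunk above; the environment variable name and the three accepted values come from the diff, everything else is illustrative.

```rust
// Minimal sketch: map the ATTENTION environment variable to a KV-cache block size,
// mirroring the match introduced above. Standard library only.
fn block_size_for(attention: &str) -> u32 {
    match attention {
        "flashinfer" => 1,
        "flashdecoding" => 256,
        "paged" => 16,
        other => panic!("Invalid attention was specified: `{other}`"),
    }
}

fn main() {
    let attention = std::env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
    println!("block_size = {}", block_size_for(&attention));
}
```
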
@@ -1,12 +1,14 @@
-use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
 /// Batching and inference logic
+use crate::client::{
+    Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
+};
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;

@@ -31,32 +33,22 @@ impl BackendV3 {
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
-        support_chunking: bool,
+        shard_info: InfoResponse,
     ) -> Self {
-        if support_chunking {
+        if shard_info.support_chunking {
             tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
         }

-        let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
-        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
-
-        let attention: Attention = attention
-            .parse()
-            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
-        let block_size = attention.block_size();
+        let block_size = shard_info.block_size;

         let queue = Queue::new(
-            requires_padding,
+            shard_info.requires_padding,
             block_size,
-            prefix_caching,
-            window_size,
-            speculate,
+            shard_info.use_prefix_caching,
+            shard_info.window_size,
+            shard_info.speculate,
             max_batch_total_tokens,
-            support_chunking,
+            shard_info.support_chunking,
         );
         let batching_task_notifier = Arc::new(Notify::new());

@@ -68,7 +60,7 @@ impl BackendV3 {
             max_batch_total_tokens,
             max_waiting_tokens,
             max_batch_size,
-            support_chunking,
+            shard_info.support_chunking,
             queue.clone(),
             batching_task_notifier.clone(),
         ));

@@ -154,7 +146,7 @@ pub(crate) async fn batching_task(
         )
         .await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;

@@ -175,7 +167,8 @@ pub(crate) async fn batching_task(
                 let (min_size, max_size, prefill_token_budget) = if support_chunking {
                     // Since the next batch will be concatenated with the current batch,
                     // the current batch tokens must be subtracted to the prefill budget
-                    let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
+                    let prefill_token_budget =
+                        max_batch_prefill_tokens.saturating_sub(current_tokens);
                     // We can ignore min_size and max_size
                     // Models than rely on max_size cannot support chunking
                     // Regarding min_size, chunking allow us to consistently run at the compute
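A small illustration (assumed numbers, not repository code) of why the budget computation above switches to `saturating_sub`: with unsigned token counts, a plain subtraction fails when the running batch already uses more than the prefill budget, while `saturating_sub` clamps the remaining budget to zero.

```rust
fn main() {
    let max_batch_prefill_tokens: u32 = 4096;
    let current_tokens: u32 = 5000; // hypothetical: current batch already exceeds the budget

    // `max_batch_prefill_tokens - current_tokens` would panic in debug builds
    // (and wrap around in release builds); saturating_sub simply floors at zero.
    let prefill_token_budget = max_batch_prefill_tokens.saturating_sub(current_tokens);
    assert_eq!(prefill_token_budget, 0);
    println!("prefill_token_budget = {prefill_token_budget}");
}
```
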
@@ -199,10 +192,8 @@ pub(crate) async fn batching_task(
                     (min_size, max_size, max_batch_prefill_tokens)
                 };

-                let mut additional_batch = None;
-
                 // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) = queue
+                if let Some((new_entries, new_batch, span)) = queue
                     .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                     .await
                 {

@@ -218,11 +209,11 @@ pub(crate) async fn batching_task(
                         };
                         counter.increment(1);
                     }

-                    if support_chunking {
-                        entries.extend(new_entries);
-                        additional_batch = Some(new_batch);
+                    let cached_batch = if support_chunking {
+                        // Concat current batch to the new one
+                        batches.pop()
                     } else {
+                        // Request are waiting only if we don't support chunking
                         entries.iter_mut().for_each(|(_, entry)| {
                             // Create a new span to add the info that this entry is waiting
                             // because a new batch is being computed

@@ -233,18 +224,23 @@ pub(crate) async fn batching_task(
                             // Update entry
                             entry.temp_span = Some(entry_waiting_span);
                         });
+                        None
+                    };
+                    entries.extend(new_entries);

                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, cached_batch, &mut entries)
                             .instrument(span)
                             .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
                     if let Some(new_cached_batch) = new_cached_batch {
-                        entries.extend(new_entries);
                         batches.push(new_cached_batch);
-                    }
+                    } else if support_chunking {
+                        // New cached batch is empty, no work left
+                        break;
                     }
                 }

@@ -262,7 +258,7 @@ pub(crate) async fn batching_task(
                 entry.temp_span = Some(entry_batch_span);
             });

-            cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
+            cached_batch = decode(&mut client, batches, &mut entries)
                 .instrument(next_batch_span)
                 .await;
             waiting_tokens += 1;

@@ -277,13 +273,14 @@ pub(crate) async fn batching_task(
 async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
+    cached_batch: Option<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

-    match client.prefill(batch).await {
+    match client.prefill(batch, cached_batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries

@@ -292,6 +289,10 @@ async fn prefill(
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                    .record(concat_duration.as_secs_f64());
+            }
             metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
                 .record(timings.forward.as_secs_f64());
             metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")

@@ -316,7 +317,6 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batch: Option<Batch>,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {

@@ -324,7 +324,7 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

-    match client.decode(batch, batches).await {
+    match client.decode(batches).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries

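Taken together, the hunks above rewire the batching loop: a newly scheduled batch is prefilled together with the currently cached one (prefill now takes an `Option<CachedBatch>`) instead of being injected later through decode. Below is a toy, self-contained sketch of that feedback loop with stand-in types that are not the repository's; only the general shape is taken from the diff.

```rust
// Toy model of the chunked-prefill feedback loop: each prefill consumes a chunk of
// the prompt and returns a cached batch, which is fed back into the next prefill
// until the prompt is exhausted; decode afterwards only sees cached batches.
#[derive(Debug)]
struct CachedBatch {
    remaining_prompt_tokens: u32,
}

fn prefill(chunk_budget: u32, cached: Option<CachedBatch>) -> CachedBatch {
    let batch = cached.unwrap_or(CachedBatch { remaining_prompt_tokens: 10_000 });
    CachedBatch {
        remaining_prompt_tokens: batch.remaining_prompt_tokens.saturating_sub(chunk_budget),
    }
}

fn decode(batches: &[CachedBatch]) -> usize {
    batches.len() // stand-in for one decode forward over the cached batches
}

fn main() {
    let chunk_budget = 4096;
    let mut cached: Option<CachedBatch> = None;
    // Feedback loop: keep prefilling until the whole prompt sits in the KV cache.
    loop {
        let batch = prefill(chunk_budget, cached.take());
        let done = batch.remaining_prompt_tokens == 0;
        cached = Some(batch);
        if done {
            break;
        }
    }
    println!("decode steps: {}", decode(&[cached.unwrap()]));
}
```
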
@@ -218,13 +218,23 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
             response.batch,
-            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+            PrefillTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
         ))
     }

@@ -235,10 +245,9 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,

@@ -254,14 +263,16 @@ impl Client {
 }

 pub struct PrefillTimings {
+    pub concat: Option<Duration>,
     pub forward: Duration,
     pub decode: Duration,
     pub total: Duration,
 }

 impl PrefillTimings {
-    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
         Self {
+            concat: concat_ns.map(Duration::from_nanos),
             forward: Duration::from_nanos(forward_ns),
             decode: Duration::from_nanos(decode_ns),
             total: Duration::from_nanos(total_ns),
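A minimal standalone illustration (assumed values) of the new optional concat timing above: the nanosecond field is only present when a cached batch was concatenated during prefill, and `Option::map` carries that through to an `Option<Duration>`.

```rust
use std::time::Duration;

fn main() {
    // Hypothetical responses: one prefill that concatenated a cached batch, one that did not.
    let with_concat: Option<u64> = Some(1_250_000); // 1.25 ms spent concatenating
    let without_concat: Option<u64> = None;

    for concat_ns in [with_concat, without_concat] {
        let concat: Option<Duration> = concat_ns.map(Duration::from_nanos);
        match concat {
            Some(d) => println!("concat took {:.3} ms", d.as_secs_f64() * 1e3),
            None => println!("no concatenation for this prefill"),
        }
    }
}
```
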
@@ -135,11 +135,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =

@@ -167,13 +168,12 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
     pub async fn decode(
         &mut self,
-        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
+            .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =

@@ -246,7 +246,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }

@@ -31,6 +31,12 @@ pub struct BackendInfo {
     pub max_batch_size: Option<usize>,
     #[schema(example = "false")]
     pub support_chunking: bool,
+    #[schema(example = "false")]
+    pub prefix_caching: bool,
+    #[schema(example = "flashinfer")]
+    pub attention_impl: String,
+    #[schema(example = "1")]
+    pub block_size: u32,
 }

 #[allow(clippy::too_many_arguments)]

@@ -113,6 +119,9 @@ pub async fn connect_backend(
         model_dtype: shard_info.dtype.clone(),
         speculate: shard_info.speculate as usize,
         support_chunking: shard_info.support_chunking,
+        prefix_caching: shard_info.use_prefix_caching,
+        attention_impl: shard_info.attention_impl.clone(),
+        block_size: shard_info.block_size,
     };

     let backend = BackendV3::new(

@@ -122,10 +131,7 @@ pub async fn connect_backend(
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
-        shard_info.requires_padding,
-        shard_info.window_size,
-        shard_info.speculate,
-        shard_info.support_chunking,
+        shard_info,
     );

     tracing::info!("Using backend V3");

@@ -89,6 +89,10 @@ impl Queue {
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
+        if prefill_token_budget == 0 || token_budget == 0 {
+            return None;
+        };
+
         // Create response channel
         let (response_sender, response_receiver) = oneshot::channel();
         // Send next batch command to the background task managing the state

@@ -174,7 +174,7 @@ async fn prefill(

     // Run prefill
     let start_time = Instant::now();
-    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;

     // Get latency
     let latency = start_time.elapsed();

@@ -35,6 +35,9 @@ message InfoResponse {
   optional uint32 window_size = 4;
   uint32 speculate = 5;
   bool support_chunking = 6;
+  bool use_prefix_caching = 7;
+  string attention_impl = 8;
+  uint32 block_size = 9;
 }

 /// Empty request

@@ -225,6 +228,8 @@ message FilterBatchResponse {
 message PrefillRequest {
   /// Batch
   Batch batch = 1;
+  /// Optional cached batch
+  CachedBatch cached_batch = 2;
 }

 message PrefillResponse {

@@ -238,13 +243,13 @@ message PrefillResponse {
   uint64 decode_ns = 4;
   /// Total elapsed time in nanoseconds
   uint64 total_ns = 5;
+  /// Concatenate elapsed time in nanoseconds
+  optional uint64 concat_ns = 6;
 }

 message DecodeRequest {
   /// Cached batches
   repeated CachedBatch batches = 1;
-  /// Optional Batch
-  optional Batch batch = 2;
 }

 message DecodeResponse {
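For readers following the Rust client hunks above: a proto3 field marked `optional`, like `concat_ns` here, is surfaced by the prost/tonic code generation as an `Option` on the Rust side, which is what `PrefillTimings::new(concat_ns, ...)` relies on. A hand-written stand-in (not the generated code; field names taken from the proto, everything else assumed):

```rust
// Hand-written stand-in for the generated PrefillResponse message.
#[derive(Debug, Default)]
struct PrefillResponseLike {
    forward_ns: u64,
    decode_ns: u64,
    total_ns: u64,
    concat_ns: Option<u64>, // `optional uint64 concat_ns = 6;` => None when absent
}

fn main() {
    let response = PrefillResponseLike {
        forward_ns: 3_000_000,
        decode_ns: 500_000,
        total_ns: 3_600_000,
        concat_ns: None, // no cached batch was concatenated for this request
    };
    println!("{response:?}");
}
```
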
@@ -18,45 +18,6 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;

-#[derive(PartialEq)]
-pub enum Attention {
-    Paged,
-    FlashDecoding,
-    FlashInfer,
-}
-
-impl Attention {
-    pub fn block_size(&self) -> u32 {
-        match self {
-            Attention::FlashDecoding => 256,
-            Attention::FlashInfer => 1,
-            Attention::Paged => 16,
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct ParseError;
-
-impl std::fmt::Display for ParseError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Cannot parse attention value")
-    }
-}
-impl std::error::Error for ParseError {}
-
-impl std::str::FromStr for Attention {
-    type Err = ParseError;
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "paged" => Ok(Attention::Paged),
-            "flashdecoding" => Ok(Attention::FlashDecoding),
-            "flashinfer" => Ok(Attention::FlashInfer),
-            _ => Err(ParseError),
-        }
-    }
-}
-
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {

@@ -76,7 +76,7 @@ class CausalLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self),
+            current_tokens=len(self.input_ids),
         )

     @classmethod

@@ -171,7 +171,7 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward
     prefill_cu_outlens: Optional[List[int]]
     # Will be set by `generate_token` and reset after each prefill forward
-    prefill_tokens: List[Optional[Tokens]]
+    prefill_logprob_tokens: List[Optional[Tokens]]

     # Prefixes
     prefix_ids: List[List[int]]

@@ -290,8 +290,7 @@ class FlashCausalLMBatch(Batch):
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
-                assert prefix_length > 0
-                prefix_length -= 1
+                assert False, "unreachable"
             if prefix_length + postfix_length < prompt_length:
                 # FIXME: speculate is not supported for context chunking at the moment
                 assert speculate == 0

@@ -303,7 +302,9 @@ class FlashCausalLMBatch(Batch):
                     prefix_length : prefix_length + postfix_length
                 ]

-                postfix_length = len(postfix_ids)
+                assert (
+                    len(postfix_ids) == postfix_length
+                ), "Rust and Python tokenizers are not aligned"
                 postfix_lengths.append(postfix_length)

             prefix_offsets.append(prompt_length - 5)

@@ -394,7 +395,7 @@ class FlashCausalLMBatch(Batch):
             max_current_length=max_current_length,
             prefilling=True,
             prefilling_mask=[True] * len(pb.requests),
-            prefill_tokens=[None] * len(pb.requests),
+            prefill_logprob_tokens=[None] * len(pb.requests),
             postfix_lengths=postfix_lengths,
             prompt_lengths=prompt_lengths,
             prefix_offsets=prefix_offsets,

@@ -475,7 +476,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []

         prefilling_mask = []
-        prefill_tokens = []
+        prefill_logprob_tokens = []

         stopping_criterias = []
         top_n_tokens = []

@@ -518,7 +519,7 @@ class FlashCausalLMBatch(Batch):
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])
-            prefill_tokens.append(self.prefill_tokens[idx])
+            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)

@@ -611,7 +612,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            prefill_tokens=prefill_tokens,
+            prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,

@@ -726,7 +727,7 @@ class FlashCausalLMBatch(Batch):
         prefix_offsets = []
         read_offsets = []

-        prefill_tokens = []
+        prefill_logprob_tokens = []

         next_token_chooser_parameters = []
         fsm_grammar_states = []

@@ -814,7 +815,7 @@ class FlashCausalLMBatch(Batch):
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

-            prefill_tokens.extend(batch.prefill_tokens)
+            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)

@@ -869,7 +870,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            prefill_tokens=prefill_tokens,
+            prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,

@@ -1769,9 +1770,10 @@ class FlashCausalLM(Model):
         if get_support_chunking():
             next_prefilling_mask = []
             # Budget in tokens for the next batch
-            # We remove len(batch) to always have enough space for at least a single decode
-            # for the remaining requests
-            batch_budget = get_max_prefill_tokens() - len(batch)
+            # We remove (len(batch) - 1) to always have enough space for at least a single decode
+            # for the remaining requests -1 because the first request does not need to be removed from the budget
+            # (ex: you have one request in the batch, you want it to take the full budget not budget -1)
+            batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
             # We reverse to prioritize older requests
             # zip() is not reversible so reverse the underlying lists instead
             for prefix_length, postfix_length, prompt_length in zip(

@@ -1790,6 +1792,7 @@ class FlashCausalLM(Model):
                     finished_prefilling = False
                     next_prefilling_mask.append(True)
                 else:
                     # FIXME: use true number of accepted tokens instead of 1
+                    # Since speculation will be turned off, this is always true
                     next_chunk_length = 1
                     next_prefilling_mask.append(False)

@@ -1807,14 +1810,7 @@ class FlashCausalLM(Model):
         batch.prefilling = not finished_prefilling
         batch.prefilling_mask = next_prefilling_mask

-        # Turn off speculative if some requests are still prefilling
-        # It makes the logic easier to follow
-        if prefill and not finished_prefilling:
-            speculate = 0
-            speculative_logits = None
-        else:
-            speculate = get_speculate()

         (
             next_input_ids,
             next_token_logprobs,

@@ -2045,18 +2041,18 @@ class FlashCausalLM(Model):
                 # this state to be stable
                 if request.id % self.world_size == self.rank:
                     # Prefill
-                    if prefill and request.prefill_logprobs:
+                    if request_prefilling and request.prefill_logprobs:
                         out_start_index = batch.prefill_cu_outlens[i]
                         out_end_index = batch.prefill_cu_outlens[i + 1]

-                        request_prefill_tokens = batch.prefill_tokens[i]
-
                         request_prefill_logprobs = prefill_logprobs[
                             out_start_index : out_end_index - 1
                         ]
                         prefill_token_ids = all_input_ids[:-1]

-                        if request_prefill_tokens is None:
+                        past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
+
+                        if past_prefill_logprob_tokens is None:
                             # Remove generated token to only have prefill and add nan for first prompt token
                             request_prefill_logprobs = [float("nan")] * (
                                 len(prefix_ids) + 1

@@ -2069,18 +2065,20 @@ class FlashCausalLM(Model):
                                 skip_special_tokens=False,
                             )

-                            prefill_tokens = Tokens(
+                            prefill_logprob_tokens = Tokens(
                                 prefill_token_ids,
                                 request_prefill_logprobs,
                                 prefill_texts,
                                 is_special=[],
                             )
-                            if request_prefill_tokens is not None:
-                                prefill_tokens = request_prefill_tokens + prefill_tokens
+                            if past_prefill_logprob_tokens is not None:
+                                prefill_logprob_tokens = (
+                                    past_prefill_logprob_tokens + prefill_logprob_tokens
+                                )

-                            batch.prefill_tokens[i] = prefill_tokens
+                            batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                         else:
-                            batch.prefill_tokens[i] = None
+                            batch.prefill_logprob_tokens[i] = None

                         # If it is, the tokens we decoded should be ignored
                         if request_prefilling:

@@ -2178,7 +2176,7 @@ class FlashCausalLM(Model):

                 generation = Generation(
                     request.id,
-                    batch.prefill_tokens[i],
+                    batch.prefill_logprob_tokens[i],
                     Tokens(
                         _next_token_ids,
                         _next_token_logprobs,
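The rewritten comment above explains the `- (len(batch) - 1)` term: every request other than the first must keep room for at least one decode token, so a single-request batch keeps the full prefill budget. A toy Rust rendering of that arithmetic (the server code itself is Python; the names here are purely illustrative):

```rust
// Reserve one decode slot for every request except the first.
fn batch_budget(max_prefill_tokens: u32, batch_len: u32) -> u32 {
    max_prefill_tokens - (batch_len - 1)
}

fn main() {
    assert_eq!(batch_budget(4096, 1), 4096); // one request: nothing reserved
    assert_eq!(batch_budget(4096, 8), 4089); // seven slots reserved for the other requests
    println!("ok");
}
```
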
@@ -7,6 +7,7 @@ from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
 from loguru import logger

+from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE
 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.log import log_master
 from text_generation_server.utils.prefill_chunking import set_support_chunking

@@ -94,6 +95,9 @@ class Model(ABC):
             window_size=self.sliding_window,
             speculate=self.speculate,
             support_chunking=self.support_chunking,
+            use_prefix_caching=PREFIX_CACHING,
+            attention_impl=ATTENTION,
+            block_size=BLOCK_SIZE,
         )

     @property

@@ -80,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self),
+            current_tokens=len(self.input_ids),
         )

     @classmethod

@@ -153,6 +153,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

+        concat_ns = None
+        if self.model.support_chunking:
+            if request.HasField("cached_batch"):
+                cached_batch = self.cache.pop(request.cached_batch.id)
+                if cached_batch is None:
+                    raise ValueError(
+                        f"Batch ID {request.cached_batch.id} not found in cache."
+                    )
+                start_concat = time.time_ns()
+                batch = self.model.batch_type.concatenate([batch, cached_batch])
+                concat_ns = time.time_ns() - start_concat
+
         generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

@@ -162,6 +174,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             forward_ns=timings[0],
             decode_ns=timings[1],
             total_ns=time.time_ns() - start,
+            concat_ns=concat_ns,
         )

     async def Decode(self, request, context):

@@ -179,16 +192,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")

-        if self.model.support_chunking:
-            if request.HasField("batch"):
-                batch = self.model.batch_type.from_pb(
-                    request.batch,
-                    self.model.tokenizer,
-                    self.model.dtype,
-                    self.model.device,
-                )
-                batches.append(batch)
-
         if len(batches) > 1:
             start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)