Mirror of https://github.com/huggingface/text-generation-inference.git
FlashCausalLM implem

commit 73c3903214
parent 6983ec9537
@@ -224,11 +224,18 @@ message FilterBatchRequest {
     repeated uint64 terminated_request_ids = 3;
 }

+message TerminatedGeneration {
+    // Request ID
+    uint64 id = 1;
+    // Generated text
+    GeneratedText generated_text = 2;
+}
+
 message FilterBatchResponse {
     /// Filtered Batch (cached)
     CachedBatch batch = 1;
     /// Terminated generations
-    repeated GeneratedText terminated_generations = 2;
+    repeated TerminatedGeneration terminated_generations = 2;
 }

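The new TerminatedGeneration message lets a shard report a final GeneratedText for every request it terminates while filtering, instead of silently dropping it. Below is a minimal sketch of how a caller might consume that pairing, using plain Rust structs as illustrative stand-ins for the prost-generated types (field and type names mirror the proto above but this is not the generated code):

// Simplified stand-ins for the prost-generated messages; illustrative only.
#[derive(Clone, Debug)]
struct GeneratedText {
    text: String,
}

#[derive(Clone, Debug)]
struct TerminatedGeneration {
    // Request ID
    id: u64,
    // Generated text
    generated_text: GeneratedText,
}

#[derive(Clone, Debug)]
struct FilterBatchResponse {
    batch: Option<u64>, // stand-in for the cached batch
    terminated_generations: Vec<TerminatedGeneration>,
}

fn main() {
    // A response where one request survived filtering and request 3 was terminated.
    let response = FilterBatchResponse {
        batch: Some(1),
        terminated_generations: vec![TerminatedGeneration {
            id: 3,
            generated_text: GeneratedText { text: "...final text...".into() },
        }],
    };

    // The caller can route each terminated generation back to the waiting
    // request by its id, instead of only learning that the request disappeared.
    for term in &response.terminated_generations {
        println!("request {} terminated with: {}", term.id, term.generated_text.text);
    }
    println!("surviving batch: {:?}", response.batch);
}

The id field is what later lets the router match each terminated generation back to the entry that is still holding an open response channel.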
@@ -92,7 +92,7 @@ impl Client {
         batch_id: u64,
         kept_requests: Vec<KeptRequest>,
         terminated_request_ids: Vec<u64>,
-    ) -> Result<Option<CachedBatch>> {
+    ) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
             kept_requests,
@@ -100,7 +100,7 @@ impl Client {
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
-        Ok(filtered_batch.batch)
+        Ok((filtered_batch.batch, filtered_batch.terminated_generations))
     }

     /// Warmup on a max size batch
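With the updated signature, callers of filter_batch receive the filtered batch and the terminated generations together. A small synchronous sketch of the call shape follows; the real method is async and talks to the shard over gRPC, and filter_batch_mock, its types, and its behaviour here are made up purely for illustration:

// Hypothetical stand-ins: this only mimics the shape of the updated
// `filter_batch` return value, it is not the crate's code.
#[derive(Debug)]
struct CachedBatch { id: u64, size: u32 }

#[derive(Debug)]
struct TerminatedGeneration { id: u64, text: String }

fn filter_batch_mock(
    _batch_id: u64,
    kept_request_ids: Vec<u64>,
    terminated_request_ids: Vec<u64>,
) -> (Option<CachedBatch>, Vec<TerminatedGeneration>) {
    // The batch disappears entirely when nothing is kept.
    let batch = if kept_request_ids.is_empty() {
        None
    } else {
        Some(CachedBatch { id: 0, size: kept_request_ids.len() as u32 })
    };
    let terminated = terminated_request_ids
        .into_iter()
        .map(|id| TerminatedGeneration { id, text: String::from("...") })
        .collect();
    (batch, terminated)
}

fn main() {
    // Previously only the filtered batch came back and terminated requests vanished.
    // Now both pieces arrive together and can be handled in one place.
    let (filtered_batch, terminated) = filter_batch_mock(0, vec![1, 2], vec![3]);
    println!("kept batch: {:?}", filtered_batch);
    println!("terminated: {:?}", terminated);
}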
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens,
 };
 pub use sharded_client::ShardedClient;
@@ -2,7 +2,7 @@
 use crate::{v3, Health, ShardInfo};
 use crate::{ClientError, Result};

-use crate::v3::{Chunk, InfoResponse, Input};
+use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
 use async_trait::async_trait;
 use futures::future::join_all;
 use tonic::transport::Uri;
@@ -86,7 +86,7 @@ impl ShardedClient {
         batch_id: u64,
         kept_requests: Vec<KeptRequest>,
         terminated_request_ids: Vec<u64>,
-    ) -> Result<Option<CachedBatch>> {
+    ) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -5,12 +5,14 @@ use crate::infer::{
 };
 use crate::validation::ValidGenerateRequest;
 use crate::{FinishReason, PrefillToken, Token};
-use nohash_hasher::IntMap;
+use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
+use text_generation_client::v3::{
+    Batch, CachedBatch, Generation, KeptRequest, ShardedClient, TerminatedGeneration,
+};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -243,11 +245,38 @@ async fn prefill(
         generation_health.store(true, Ordering::SeqCst);

         let start_filtering_time = Instant::now();
-        // Send generated tokens and filter stopped entries
-        filter_send_generations(generations, entries);
+        // Filter and send finished generations
+        let filtered_stream_responses = filter_send_ended_generations(generations, entries);

+        // Iterate on intermediate generations
+        for (id, stream_responses) in filtered_stream_responses {
+            // Get entry
+            let entry = entries
+                .get_mut(&id)
+                .expect("ID not found in entries. This is a bug.");
+
+            // Send intermediate responses
+            if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
+                tracing::error!("Entry response channel error.");
+                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+                err
+            }) {
+                // Sending failed, remove entry
+                entries
+                    .remove(&id)
+                    .expect("ID not found in entries. This is a bug.");
+            }
+        }
+
         // Filter next batch and remove requests that were stopped
-        let next_batch = filter_batch(client, next_batch, entries, false).await;
+        let next_batch = match next_batch {
+            Some(batch) if batch.size as usize != entries.len() => {
+                let (filtered_batch, _) =
+                    filter_batch(client, batch, entries, &IntMap::default()).await;
+                filtered_batch
+            }
+            batch => batch,
+        };

         metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
         metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
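The prefill path now forwards the buffered intermediate responses itself and only calls filter_batch when some requests actually left the batch (batch.size differs from the number of live entries). A standalone sketch of that guard, with CachedBatch and the filtering step reduced to stand-ins rather than the crate's types:

// Illustrative sketch of the "only filter when the batch shrank" guard above.
#[derive(Debug, Clone)]
struct CachedBatch { id: u64, size: u32 }

fn filter_if_needed(
    next_batch: Option<CachedBatch>,
    live_entries: usize,
    mut do_filter: impl FnMut(CachedBatch) -> Option<CachedBatch>,
) -> Option<CachedBatch> {
    match next_batch {
        // Only pay for the filter RPC when some requests finished this step.
        Some(batch) if batch.size as usize != live_entries => do_filter(batch),
        // Otherwise keep the cached batch untouched.
        batch => batch,
    }
}

fn main() {
    let batch = Some(CachedBatch { id: 0, size: 4 });
    // 4 live entries: nothing finished, the filter closure is never called.
    let same = filter_if_needed(batch.clone(), 4, |b| Some(b));
    // 3 live entries: one request finished, the filter closure runs.
    let filtered = filter_if_needed(batch, 3, |mut b| { b.size = 3; Some(b) });
    println!("{:?} {:?}", same, filtered);
}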
@@ -285,13 +314,32 @@ async fn decode(
         generation_health.store(true, Ordering::SeqCst);

         let start_filtering_time = Instant::now();
-        // Send generated tokens and filter stopped entries
-        filter_send_generations(generations, entries);
-
-        let updated = filter_update_allocations(entries).await;
+        // Filter and send finished generations
+        let mut filtered_stream_responses = filter_send_ended_generations(generations, entries);
+        // Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be
+        // re-allocated,
+        // Allocated new blocks for entries that go over their allocation
+        // Filter entries that couldn't be re-allocated and add them to `terminated_entries`
+        let (force_update, terminated_entries) =
+            filter_send_update_allocations(entries, &mut filtered_stream_responses);

         // Filter next batch and remove requests that were stopped
-        let next_batch = filter_batch(client, next_batch, entries, updated).await;
+        let next_batch = match next_batch {
+            // Run Only on re-allocation or if entries were filtered
+            Some(batch) if batch.size as usize != entries.len() || force_update => {
+                // Filter next batch: remove requests that were stopped and update blocks/slots
+                let (filtered_batch, terminated_generations) =
+                    filter_batch(client, batch, entries, &terminated_entries).await;
+                send_terminated_generations(
+                    terminated_generations,
+                    terminated_entries,
+                    filtered_stream_responses,
+                );
+
+                filtered_batch
+            }
+            batch => batch,
+        };

         if let Some(concat_duration) = timings.concat {
             metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
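On the decode path, requests whose block allocation cannot be extended are pulled out into terminated_entries; once the shard reports their final text, send_terminated_generations rewrites the last buffered intermediate response into an End response. A minimal sketch of that "promote the last buffered response" step, using a simplified enum instead of the router's InferStreamResponse:

// Simplified stand-in for InferStreamResponse; illustrative only.
#[derive(Debug)]
enum StreamResponse {
    Intermediate { token: String },
    End { token: String, generated_text: String },
}

fn finalize(responses: Vec<StreamResponse>, full_text: &str) -> Vec<StreamResponse> {
    // Peekable iterator to know when we reach the last buffered response.
    let mut iterator = responses.into_iter().peekable();
    let mut out = Vec::new();
    while let Some(response) = iterator.next() {
        let is_last = iterator.peek().is_none();
        match (is_last, response) {
            // Last buffered response: promote the intermediate token to an End
            // response carrying the complete generated text reported by the shard.
            (true, StreamResponse::Intermediate { token }) => out.push(StreamResponse::End {
                token,
                generated_text: full_text.to_string(),
            }),
            (_, other) => out.push(other),
        }
    }
    out
}

fn main() {
    let buffered = vec![
        StreamResponse::Intermediate { token: "Hel".into() },
        StreamResponse::Intermediate { token: "lo".into() },
    ];
    println!("{:?}", finalize(buffered, "Hello"));
}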
@@ -320,27 +368,20 @@ async fn decode(
 #[instrument(skip_all)]
 async fn filter_batch(
     client: &mut ShardedClient,
-    next_batch: Option<CachedBatch>,
+    batch: CachedBatch,
     entries: &IntMap<u64, Entry>,
-    force_update: bool,
-) -> Option<CachedBatch> {
-    let batch = next_batch?;
-
-    // No need to filter
-    if batch.size as usize == entries.len() && !force_update {
-        return Some(batch);
-    }
-
+    terminated_entries: &IntMap<u64, Entry>,
+) -> (Option<CachedBatch>, Vec<TerminatedGeneration>) {
     let id = batch.id;
-    if entries.is_empty() {
+    if entries.is_empty() && terminated_entries.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client.clear_cache(Some(id)).await.unwrap();
-        None
+        Default::default()
     } else {
-        // Filter Python shard cache
+        // Collect new blocks/slots
         let updated_requests = entries
             .iter()
             .map(|(request_id, entry)| {
@@ -348,7 +389,7 @@ async fn filter_batch(
                     .block_allocation
                     .as_ref()
                     .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
-                    .unwrap_or((Vec::new(), Vec::new()));
+                    .unwrap_or_default();

                 KeptRequest {
                     id: *request_id,
@@ -358,111 +399,207 @@ async fn filter_batch(
             })
             .collect();

+        // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client
-            .filter_batch(id, updated_requests, Vec::new())
+            .filter_batch(
+                id,
+                updated_requests,
+                terminated_entries.keys().map(|v| *v).collect(),
+            )
             .await
             .unwrap()
     }
 }

-/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
-/// and filter entries
+///
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
-    generations.into_iter().for_each(|generation| {
+fn send_terminated_generations(
+    terminated_generations: Vec<TerminatedGeneration>,
+    terminated_entries: IntMap<u64, Entry>,
+    mut stream_responses: IntMap<u64, Vec<InferStreamResponse>>,
+) {
+    // Receive final message for terminated generations
+    'terminated_generations: for terminated_generation in terminated_generations {
+        let id = terminated_generation.id;
+        // Get entry for this generation
+        let entry = terminated_entries
+            .get(&id)
+            .expect("ID not found in entries. This is a bug.");
+        // Get previous `InferStreamResponse` for this generation
+        let stream_responses = stream_responses
+            .remove(&id)
+            .expect("ID not found in stream_responses. This is a bug.");
+
+        // Peekable iterator to know when we are at the last `InferStreamResponse`
+        let mut iterator = stream_responses.into_iter().peekable();
+
+        while let Some(stream_response) = iterator.next() {
+            let response = if iterator.peek().is_none() {
+                // Last `InferStreamResponse::Intermediate`
+                let (token, top_tokens) = match stream_response {
+                    InferStreamResponse::Intermediate { token, top_tokens } => (token, top_tokens),
+                    _ => unreachable!(),
+                };
+                // Modify it to be a `InferStreamResponse::End` with the new OutOfResources finish
+                // reason
+                InferStreamResponse::End {
+                    token,
+                    top_tokens,
+                    generated_text: GeneratedText::from(
+                        terminated_generation
+                            .generated_text
+                            .clone()
+                            .expect("Generated Text is None. This is a bug."),
+                    ),
+                    queued: entry.queue_time,
+                    start: entry.batch_time.unwrap(),
+                }
+            } else {
+                stream_response
+            };
+
+            // Send responses
+            if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
+                tracing::error!("Entry response channel error.");
+                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+                err
+            }) {
+                continue 'terminated_generations;
+            }
+        }
+    }
+}
+
+/// Send `InferStreamResponse::End` to `Infer` for finished entries and remove them from `entries`
+/// Returns filtered `InferStreamResponse::Intermediate` generations
+#[instrument(skip_all)]
+fn filter_send_ended_generations(
+    generations: Vec<Generation>,
+    entries: &mut IntMap<u64, Entry>,
+) -> IntMap<u64, Vec<InferStreamResponse>> {
+    generations.into_iter().filter_map(|generation| {
         let id = generation.request_id;
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
             .get_mut(&id)
             .expect("ID not found in entries. This is a bug.");
-        entry.cache_length = generation.cache_length;

         // Create and enter a span to link this function back to the entry
         let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
-        // Send generation responses back to the infer task
-        // If the receive an error from the Flume channel, it means that the client dropped the
-        // request and we need to stop generating hence why we unwrap_or(true)
-        let stopped = send_responses(generation, entry).map_err(|err| {
-            tracing::error!("Entry response channel error.");
-            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-            err
-        }).unwrap_or(true);
-        if stopped {
-            entries.remove(&id).expect("ID not found in entries. This is a bug.");
-        }
-    });
-}
-
-/// Check if block allocations need to be extended
-/// If we don't have enough blocks, request will be filtered with an OutOfPages error
-#[instrument(skip_all)]
-async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
-    let ids: Vec<u64> = entries
-        .iter()
-        .filter_map(|(id, entry)| {
-            entry
-                .block_allocation
-                .as_ref()
-                .map(|block_allocation| {
-                    if entry.cache_length > block_allocation.len() as u32 {
-                        // We need to re-allocate
-                        Some(*id)
-                    } else {
-                        None
-                    }
-                })
-                .unwrap_or(None)
-        })
-        .collect();
-
-    for id in ids.iter() {
-        // Get entry
-        // We can `expect` here as the request id should always be in the entries
-        let extension = {
-            let entry = entries
-                .get_mut(id)
-                .expect("ID not found in entries. This is a bug.");
-            entry
-                .block_allocation
-                .as_mut()
-                .expect("We checked that the block allocation exists above")
-                .extend()
-        };
-
-        if extension.is_err() {
-            let entry = entries
-                .remove(id)
-                .expect("ID not found in entries. This is a bug.");
-
-            // Create and enter a span to link this function back to the entry
-            let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-            let err = InferError::OutOfPages;
-            metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
-            tracing::error!("{err}");
-
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry.response_tx.send(Err(err)).unwrap_or(());
-        }
-    }
-
-    // If ids is not empty, we need to update
-    !ids.is_empty()
-}
-
-/// Send responses through the `entry` response channel
-fn send_responses(
-    generation: Generation,
-    entry: &Entry,
-) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
         // Return directly if the channel is disconnected
         if entry.response_tx.is_closed() {
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-        return Ok(true);
+            // Remove from entries and filter
+            entries.remove(&id).expect("ID not found in entries. This is a bug.");
+            return None;
         }

-    let mut stopped = false;
+        // Update cache length
+        entry.cache_length = generation.cache_length;
+
+        let (finished, stream_responses) = map_generation(generation, entry);
+        // If the generation has ended for this request, we send the responses to the channel and
+        // remove the entry to drop it and free its blocks
+        if finished {
+            let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
+                tracing::error!("Entry response channel error.");
+                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+                err
+            });
+            // Remove from entries and filter
+            entries.remove(&id).expect("ID not found in entries. This is a bug.");
+            return None;
+        }
+
+        Some((id, stream_responses))
+    }).collect()
+}
+
+/// Send `InferStreamResponse` to `Infer` through an `Entry` response channel
+#[instrument(skip_all)]
+fn send_stream_responses(
+    stream_responses: Vec<InferStreamResponse>,
+    entry: &Entry,
+) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
+    for response in stream_responses {
+        entry.response_tx.send(Ok(response))?;
+    }
+    Ok(())
+}
+
+/// Check if block allocations need to be extended
+/// If we don't have enough blocks, request will be filtered with be added to an IntMap of
+/// terminated entries.
+/// If at least one entry allocation was extended, we return true to force an update
+#[instrument(skip_all)]
+fn filter_send_update_allocations(
+    entries: &mut IntMap<u64, Entry>,
+    stream_responses: &mut IntMap<u64, Vec<InferStreamResponse>>,
+) -> (bool, IntMap<u64, Entry>) {
+    let mut updated = false;
+
+    let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
+    let mut terminated_entries =
+        IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
+
+    for id in &ids {
+        let entry = entries
+            .get_mut(id)
+            .expect("ID not found in entries. This is a bug.");
+
+        if let Some(block_allocation) = entry.block_allocation.as_mut() {
+            // Check if allocation can handle the current cache_length
+            if entry.cache_length > block_allocation.len() as u32 {
+                updated = true;
+
+                // Extend allocation by asking for a new block
+                if let Err(err) = block_allocation.extend() {
+                    // Failed to extend allocation
+                    tracing::error!("Failed to extend allocation: {err}");
+                    metrics::increment_counter!("tgi_request_failure", "err" => "out_of_resources");
+
+                    // Remove entry
+                    let mut entry = entries
+                        .remove(id)
+                        .expect("ID not found in entries. This is a bug.");
+                    // Clear block allocation
+                    entry.block_allocation = None;
+                    // Add it to terminated entries
+                    terminated_entries.insert(*id, entry);
+                    // Skip the rest of the logic to not send the intermediate messages
+                    // This entry will be terminated and we will need to edit the last intermediate
+                    // response to add the complete generated text
+                    continue;
+                }
+            }
+        }
+        let stream_response = stream_responses
+            .remove(id)
+            .expect("ID not found in stream_responses. This is a bug.");
+
+        // Send intermediate responses
+        if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
+            tracing::error!("Entry response channel error.");
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        }) {
+            // Sending failed, remove entry
+            entries
+                .remove(id)
+                .expect("ID not found in entries. This is a bug.");
+        }
+    }
+
+    (updated, terminated_entries)
+}
+
+/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>`
+fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
+    let mut finished = false;
+    let mut stream_responses = Vec::with_capacity(16);

     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Create Token objects
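The helper functions above key everything by request id using nohash_hasher's IntMap, pre-sized with with_capacity_and_hasher. A standalone example of that construction, assuming the nohash_hasher crate (where IntMap<K, V> is a HashMap with an identity hasher, appropriate because the u64 request ids hash trivially):

// Standalone example of the IntMap construction used above.
use nohash_hasher::{BuildNoHashHasher, IntMap};

fn main() {
    // Pre-size the map for the number of live entries to avoid rehashing; the
    // key is the request id, already a well-distributed u64.
    let mut terminated: IntMap<u64, &str> =
        IntMap::with_capacity_and_hasher(8, BuildNoHashHasher::default());
    terminated.insert(3, "out of pages");
    terminated.insert(9, "out of pages");

    for (id, reason) in &terminated {
        println!("request {id} terminated: {reason}");
    }
}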
@@ -475,10 +612,8 @@ fn send_responses(
             .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
             .collect();

-        // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+        // Push to stream_responses
+        stream_responses.push(InferStreamResponse::Prefill(prefill_tokens));
     }

     // Create last Token
@@ -520,26 +655,24 @@ fn send_responses(
         match (&generation.generated_text, iterator.peek()) {
             (Some(generated_text), None) => {
                 // Generation has ended
-                stopped = true;
-                // Send message
-                entry.response_tx.send(Ok(InferStreamResponse::End {
+                finished = true;
+                // Push to stream_responses
+                stream_responses.push(InferStreamResponse::End {
                     token,
                     top_tokens,
                     generated_text: GeneratedText::from(generated_text.clone()),
                     queued: entry.queue_time,
                     start: entry.batch_time.unwrap(),
-                }))?;
+                });
             }
             _ => {
-                // Send message
-                entry
-                    .response_tx
-                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+                // Push to stream_responses
+                stream_responses.push(InferStreamResponse::Intermediate { token, top_tokens });
             }
         }
     }

-    Ok(stopped)
+    (finished, stream_responses)
 }

 /// Send errors to Infer for all `entries`
@@ -402,10 +402,7 @@ class FlashCausalLMBatch(Batch):
         model: "FlashCausalLM",
         kept_requests: List[generate_pb2.KeptRequest],
         terminated_request_ids: List[int],
-    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
-        if len(kept_requests) == 0:
-            raise ValueError("Batch must have at least one request")
-
+    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
         terminated_generations = []
         for request_id in terminated_request_ids:
             idx = self.requests_idx_mapping[request_id]
@@ -421,13 +418,19 @@ class FlashCausalLMBatch(Batch):
                 read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                 skip_special_tokens=True,
             )
-            generated_text = GeneratedText(
-                output_text,
-                stopping_criteria.current_tokens,
-                generate_pb2.FINISH_REASON_TERMINATED,
-                seed if do_sample else None,
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed if do_sample else None,
+                    ),
+                )
             )
-            terminated_generations.append(generated_text)
+        if not kept_requests:
+            return None, terminated_generations

         device = self.input_ids.device

@@ -33,7 +33,7 @@ class Batch(ABC):
         model,
         kept_requests: List[generate_pb2.KeptRequest],
         terminated_request_ids: List[int],
-    ) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
+    ) -> Tuple[Optional["Batch"], List[generate_pb2.TerminatedGeneration]]:
         raise NotImplementedError

     @classmethod
@@ -86,10 +86,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch, terminated_generations = batch.filter(
             self.model, request.kept_requests, request.terminated_request_ids
         )
-        self.cache.set(filtered_batch)
+        if filtered_batch is not None:
+            self.cache.set(filtered_batch)

         return generate_pb2.FilterBatchResponse(
-            batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
+            batch=filtered_batch.to_pb() if filtered_batch is not None else None,
+            terminated_generations=terminated_generations,
         )

     async def Warmup(self, request, context):