use relaxed

OlivierDehaene 2023-10-20 14:33:19 +02:00
parent a8e6a49946
commit 1ce58375b9
3 changed files with 5 additions and 21 deletions

Cargo.lock (generated)

@@ -2808,7 +2808,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "average",
  "clap",
@@ -2829,7 +2829,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2845,7 +2845,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2861,7 +2861,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "async-stream",
  "axum",


@@ -257,7 +257,6 @@ impl Infer {
 ///
 /// Batches requests and sends them to the inference server
 #[allow(clippy::too_many_arguments)]
-#[instrument(skip_all)]
 async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
@@ -276,7 +275,6 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        tracing::debug!("First batch");
         while let Some((mut entries, batch, span)) = queue
             .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
             .await
@@ -380,8 +378,6 @@ async fn prefill(
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
-    tracing::debug!("Prefill");
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
@@ -419,15 +415,12 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
-    tracing::debug!("Decode");
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
             generation_health.store(true, Ordering::Relaxed);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
-            tracing::debug!("filter batch");
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
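Note: the commit title appears to refer to the memory ordering visible as context in the decode hunk above. The `generation_health` flag is an `Arc<AtomicBool>` that only records whether the last prefill/decode call succeeded; no other data is published through it, so `Ordering::Relaxed` is sufficient for both the store and the matching load in the health check. A minimal, standalone sketch of that pattern, assuming a hypothetical wrapper type (the router, presumably in router/src/infer.rs, stores and loads the bare atomic directly):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical wrapper around a shared health flag updated with relaxed ordering.
#[derive(Clone)]
struct GenerationHealth(Arc<AtomicBool>);

impl GenerationHealth {
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Called after an inference call returns. No other memory is synchronized
    /// through this flag, so `Relaxed` is enough.
    fn set(&self, healthy: bool) {
        self.0.store(healthy, Ordering::Relaxed);
    }

    /// Read by a health-check path.
    fn is_healthy(&self) -> bool {
        self.0.load(Ordering::Relaxed)
    }
}

fn main() {
    let health = GenerationHealth::new();
    health.set(true);
    assert!(health.is_healthy());
}
```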
@@ -473,11 +466,9 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("Call python clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
-        tracing::info!("Call python filter batch");
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client.filter_batch(id, batch.request_ids).await.unwrap()
@@ -522,7 +513,6 @@ fn send_responses(
 ) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
     if entry.response_tx.is_disconnected() {
-        tracing::debug!("Disconnected");
         metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }
@@ -531,7 +521,6 @@ fn send_responses(
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        tracing::debug!("Send prefill");
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::Prefill(prefill_tokens)),
             Duration::from_millis(10),
@@ -548,7 +537,6 @@ fn send_responses(
     // generation.top_tokens
-    tracing::debug!("Top tokens");
     let mut top_tokens = Vec::new();
     if let Some(top_tokens_) = generation.top_tokens {
         top_tokens.extend(
@@ -571,7 +559,6 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        tracing::debug!("send final");
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::End {
                 token,
@@ -583,7 +570,6 @@ fn send_responses(
             Duration::from_millis(10),
         )?;
     } else {
-        tracing::debug!("send intermediate");
         // Send message
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::Intermediate { token, top_tokens }),
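Note: the `send_responses` hunks above all follow one pattern: bail out early if the client's receiver is gone, otherwise push each streamed message with a 10 ms `send_timeout` so a slow consumer cannot stall the batching loop. The `is_disconnected`/`send_timeout`/`SendTimeoutError` signatures match flume's channel API, which the sketch below assumes; the response enum and helper are simplified stand-ins, not the router's actual types:

```rust
use std::time::Duration;

use flume::SendTimeoutError;

/// Simplified stand-in for the router's streamed response type (hypothetical fields).
#[derive(Debug)]
enum StreamResponse {
    Intermediate(String),
    End(String),
}

/// Skip the send entirely when the client already dropped its receiver,
/// otherwise send with a short timeout. Returns whether the stream is finished.
fn send_response(
    response_tx: &flume::Sender<StreamResponse>,
    msg: StreamResponse,
) -> Result<bool, Box<SendTimeoutError<StreamResponse>>> {
    // Return directly if the channel is disconnected.
    if response_tx.is_disconnected() {
        return Ok(true); // treat the entry as stopped
    }
    let stopped = matches!(msg, StreamResponse::End(_));
    response_tx
        .send_timeout(msg, Duration::from_millis(10))
        .map_err(Box::new)?;
    Ok(stopped)
}

fn main() {
    let (tx, rx) = flume::unbounded();
    assert!(!send_response(&tx, StreamResponse::Intermediate("tok".into())).unwrap());
    assert!(send_response(&tx, StreamResponse::End("eos".into())).unwrap());
    drop(rx);
    // Receiver dropped: the helper reports the entry as stopped without sending.
    assert!(send_response(&tx, StreamResponse::Intermediate("tok".into())).unwrap());
}
```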


@@ -82,9 +82,7 @@ impl Queue {
             .unwrap();
         // Await on response channel
         // Unwrap is safe here
-        let response = response_receiver.await.unwrap();
-        tracing::debug!("Next batch: {}", response.is_some());
-        response
+        response_receiver.await.unwrap()
     }
 }
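Note: the queue hunk above shows the caller side of a command/oneshot pattern: `next_batch` pushes a command carrying a oneshot sender to a background queue task and, after this change, returns `response_receiver.await.unwrap()` directly instead of binding the response just to log it. A self-contained sketch of that pattern under assumed, simplified types (the router's real `Queue`, command, and `NextBatch` types, presumably in router/src/queue.rs, carry more fields and real batching logic):

```rust
use tokio::sync::{mpsc, oneshot};

/// Hypothetical stand-in for the router's batch type.
#[derive(Debug)]
struct NextBatch(Vec<u64>);

/// Command sent to the queue task; the caller awaits the oneshot for the reply.
struct NextBatchCommand {
    min_size: Option<usize>,
    response_sender: oneshot::Sender<Option<NextBatch>>,
}

#[derive(Clone)]
struct Queue {
    queue_sender: mpsc::UnboundedSender<NextBatchCommand>,
}

impl Queue {
    fn new() -> Self {
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
        // Background task that owns the queue state.
        tokio::spawn(queue_task(queue_receiver));
        Self { queue_sender }
    }

    async fn next_batch(&self, min_size: Option<usize>) -> Option<NextBatch> {
        let (response_sender, response_receiver) = oneshot::channel();
        // Unwrap is safe as long as the background task is alive.
        self.queue_sender
            .send(NextBatchCommand { min_size, response_sender })
            .unwrap();
        // Await on the response channel and return the payload directly,
        // mirroring the simplified code above.
        response_receiver.await.unwrap()
    }
}

async fn queue_task(mut receiver: mpsc::UnboundedReceiver<NextBatchCommand>) {
    while let Some(cmd) = receiver.recv().await {
        // Toy batching logic: hand back a single dummy batch unless the caller
        // asked for a larger minimum size than this sketch can satisfy.
        let batch = NextBatch(vec![0]);
        let response = match cmd.min_size {
            Some(min) if batch.0.len() < min => None,
            _ => Some(batch),
        };
        // The caller may have gone away; ignore a failed send.
        let _ = cmd.response_sender.send(response);
    }
}

#[tokio::main]
async fn main() {
    let queue = Queue::new();
    assert!(queue.next_batch(None).await.is_some());
    assert!(queue.next_batch(Some(2)).await.is_none());
}
```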