OlivierDehaene 2023-10-20 12:55:50 +02:00
parent df419a21e0
commit a8e6a49946

@@ -427,6 +427,7 @@ async fn decode(
             generation_health.store(true, Ordering::Relaxed);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
+            tracing::debug!("filter batch");
 
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
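Below, a minimal sketch (names are illustrative, not TGI's exact internals) of the filtering step this hunk leads into: finished entries are dropped and only the surviving request ids are forwarded to the Python shard via filter_batch.

    use std::collections::HashMap;

    /// Collect the ids of requests that have not stopped; in the router these
    /// would be the request_ids passed to client.filter_batch(id, request_ids).
    fn surviving_ids(entries: &HashMap<u64, bool>) -> Vec<u64> {
        entries
            .iter()
            .filter(|(_, stopped)| !**stopped)
            .map(|(id, _)| *id)
            .collect()
    }

    fn main() {
        // id -> stopped?
        let entries = HashMap::from([(0_u64, false), (1, true), (2, false)]);
        let mut ids = surviving_ids(&entries);
        ids.sort();
        assert_eq!(ids, vec![0, 2]);
    }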
@@ -472,9 +473,11 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
+        tracing::info!("Call python clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
+        tracing::info!("Call python filter batch");
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client.filter_batch(id, batch.request_ids).await.unwrap()
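The new tracing::info!/tracing::debug! events only appear if the router's subscriber enables those levels. A minimal sketch of such a setup, assuming the tracing and tracing-subscriber crates (with the "env-filter" feature); this is not code from this commit:

    use tracing_subscriber::{fmt, EnvFilter};

    fn main() {
        // e.g. RUST_LOG=text_generation_router=debug makes the debug! events
        // added in this commit visible; info! events need at least info level.
        fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .init();

        tracing::info!("Call python clear cache"); // emitted at info level
        tracing::debug!("filter batch");           // emitted only when debug is enabled
    }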
@@ -519,6 +522,7 @@ fn send_responses(
 ) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
     if entry.response_tx.is_disconnected() {
+        tracing::debug!("Disconnected");
         metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }
@@ -527,6 +531,7 @@ fn send_responses(
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
+        tracing::debug!("Send prefill");
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::Prefill(prefill_tokens)),
             Duration::from_millis(10),
@@ -543,6 +548,7 @@ fn send_responses(
     // generation.top_tokens
+    tracing::debug!("Top tokens");
     let mut top_tokens = Vec::new();
     if let Some(top_tokens_) = generation.top_tokens {
         top_tokens.extend(
@@ -565,6 +571,7 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
+        tracing::debug!("send final");
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::End {
                 token,
@@ -576,6 +583,7 @@ fn send_responses(
             Duration::from_millis(10),
         )?;
     } else {
+        tracing::debug!("send intermediate");
         // Send message
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::Intermediate { token, top_tokens }),
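The guard and the 10 ms deadline in send_responses match the semantics of a flume channel sender (is_disconnected, send_timeout, SendTimeoutError). A small sketch of those two behaviors, assuming the flume crate; the router's actual channel type is not shown in this diff:

    use std::time::Duration;

    fn main() {
        let (tx, rx) = flume::bounded::<u32>(1);

        // Dropping the receiver disconnects the channel; this is what the
        // is_disconnected() check detects before logging "Disconnected"
        // and counting the request as dropped.
        drop(rx);
        assert!(tx.is_disconnected());

        // send_timeout errors instead of blocking forever, mirroring the
        // Duration::from_millis(10) deadline on every send in this diff.
        assert!(tx.send_timeout(42, Duration::from_millis(10)).is_err());
    }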