Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
use relaxed
This commit is contained in:
parent a8e6a49946
commit 1ce58375b9
Cargo.lock (generated) — 8 changes
@@ -2808,7 +2808,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "average",
  "clap",
@@ -2829,7 +2829,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2845,7 +2845,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2861,7 +2861,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "async-stream",
  "axum",
@@ -257,7 +257,6 @@ impl Infer {
 ///
 /// Batches requests and sends them to the inference server
 #[allow(clippy::too_many_arguments)]
-#[instrument(skip_all)]
 async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
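The only change in this hunk is dropping the #[instrument(skip_all)] attribute from batching_task. For reference, a minimal sketch of what that attribute does: tracing's proc macro wraps each call in a span, and skip_all records none of the arguments as span fields. The function and argument names below are illustrative, not TGI's.

    // Minimal sketch of tracing's #[instrument] attribute (illustrative names;
    // assumes the tracing, tracing-subscriber, and tokio crates).
    use tracing::instrument;

    // `skip_all` keeps the span but records no argument as a field, so large
    // or non-Debug arguments do not need to be captured.
    #[instrument(skip_all)]
    async fn batching_loop(batch_size: usize) {
        // Events emitted inside the function are attached to the span.
        tracing::debug!("starting iteration with batch_size {}", batch_size);
    }

    #[tokio::main]
    async fn main() {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::DEBUG)
            .init();
        batching_loop(4).await;
    }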
@@ -276,7 +275,6 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        tracing::debug!("First batch");
         while let Some((mut entries, batch, span)) = queue
             .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
             .await
@@ -380,8 +378,6 @@ async fn prefill(
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

-    tracing::debug!("Prefill");
-
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
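The increment_counter! call kept as context above goes through the metrics facade crate with label syntax ("method" => "prefill"). A minimal sketch of that facade, assuming a crate version that still exposes this macro with the label syntax shown in the diff (the exact version is not visible here): with no recorder installed the call is a no-op, while a server binary would install an exporter at startup.

    // Sketch of the metrics facade used above. Without an installed recorder
    // the macro is a no-op; counter name and labels are copied from the diff.
    fn record_prefill() {
        metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
    }

    fn main() {
        record_prefill();
    }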
@@ -419,15 +415,12 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

-    tracing::debug!("Decode");
-
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
             generation_health.store(true, Ordering::Relaxed);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
-            tracing::debug!("filter batch");

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
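The context above shows the health flag being updated with generation_health.store(true, Ordering::Relaxed), which is presumably what the commit title "use relaxed" refers to. A minimal sketch of that pattern using only std atomics: a standalone AtomicBool that carries no other data needs nothing stronger than Relaxed ordering. Names below are illustrative, not TGI's actual types.

    // Sketch: a shared health flag where Relaxed ordering is sufficient,
    // because readers only need the latest boolean, not synchronization
    // with any surrounding memory writes.
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    fn main() {
        let generation_health = Arc::new(AtomicBool::new(false));

        // The batching side flips the flag after a successful prefill/decode.
        let writer = Arc::clone(&generation_health);
        thread::spawn(move || {
            writer.store(true, Ordering::Relaxed);
        })
        .join()
        .unwrap();

        // A health check only inspects the current value, so Relaxed is
        // enough on the read side as well.
        println!("healthy: {}", generation_health.load(Ordering::Relaxed));
    }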
@@ -473,11 +466,9 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("Call python clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
-        tracing::info!("Call python filter batch");
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client.filter_batch(id, batch.request_ids).await.unwrap()
@@ -522,7 +513,6 @@ fn send_responses(
 ) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
     if entry.response_tx.is_disconnected() {
-        tracing::debug!("Disconnected");
         metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }
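The early return above relies on the response sender knowing that its receiver is gone, and the following hunks bound each send with a 10 ms timeout. A minimal sketch of that pattern, assuming a flume-style bounded channel whose Sender exposes is_disconnected() and send_timeout(); the diff does not show which channel crate the router uses, so treat the crate choice and names as assumptions.

    // Sketch of "drop the response early if the client went away, otherwise
    // send with a bounded wait". Assumes the flume crate; illustrative names.
    use std::time::Duration;

    fn send_token(tx: &flume::Sender<String>, token: String) -> bool {
        // If the receiving side was dropped, do not bother sending.
        if tx.is_disconnected() {
            return false;
        }
        // Bound the wait so a slow consumer cannot stall the generation loop.
        tx.send_timeout(token, Duration::from_millis(10)).is_ok()
    }

    fn main() {
        let (tx, rx) = flume::bounded::<String>(8);
        assert!(send_token(&tx, "hello".to_string()));
        println!("received: {}", rx.recv().unwrap());
    }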
@@ -531,7 +521,6 @@ fn send_responses(

     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        tracing::debug!("Send prefill");
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::Prefill(prefill_tokens)),
             Duration::from_millis(10),
@@ -548,7 +537,6 @@ fn send_responses(

     // generation.top_tokens

-    tracing::debug!("Top tokens");
     let mut top_tokens = Vec::new();
     if let Some(top_tokens_) = generation.top_tokens {
         top_tokens.extend(
@@ -571,7 +559,6 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        tracing::debug!("send final");
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::End {
                 token,
@@ -583,7 +570,6 @@ fn send_responses(
             Duration::from_millis(10),
         )?;
     } else {
-        tracing::debug!("send intermediate");
         // Send message
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::Intermediate { token, top_tokens }),
@@ -82,9 +82,7 @@ impl Queue {
             .unwrap();
         // Await on response channel
         // Unwrap is safe here
-        let response = response_receiver.await.unwrap();
-        tracing::debug!("Next batch: {}", response.is_some());
-        response
+        response_receiver.await.unwrap()
     }
 }

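The simplification above returns response_receiver.await.unwrap() directly instead of binding the value just to log it. The underlying pattern is a oneshot request/response: the queue task receives a command carrying a oneshot sender and replies on it. A minimal sketch with tokio's oneshot channel, under the assumption that this is the channel type in use; names are illustrative.

    // Sketch of the oneshot request/response pattern behind
    // `response_receiver.await.unwrap()`. Assumes tokio; illustrative names.
    use tokio::sync::oneshot;

    #[tokio::main]
    async fn main() {
        let (response_sender, response_receiver) = oneshot::channel::<Option<u64>>();

        // The queue task computes the next batch id and answers on the oneshot.
        tokio::spawn(async move {
            let next_batch = Some(42u64);
            // This unwrap is safe only while the requester keeps its receiver alive.
            response_sender.send(next_batch).unwrap();
        });

        // Returning the awaited value directly avoids the intermediate binding
        // that the old code kept only for the debug log.
        let next_batch = response_receiver.await.unwrap();
        println!("next batch: {next_batch:?}");
    }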