mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00
Added a model name label to metrics and an optional --served-model-name argument
This commit is contained in:
parent 5eec3a8bb6
commit 380e73dba9
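In short: the router gains an optional --served-model-name argument that falls back to model_id when unset, and the resolved name is attached as a model_name label to the batching and queueing metrics touched below. A minimal sketch of both pieces, using the clap and metrics crates the same way the diff does (the struct is an illustrative subset of the real Args, not the upstream definition):

```rust
// Sketch only; assumes the `clap` crate (with `derive` and `env` features) and the `metrics` crate.
use clap::Parser;

/// Illustrative subset of the router arguments, not the full upstream struct.
#[derive(Parser, Debug)]
struct Args {
    /// Model to load.
    #[clap(long, env)]
    model_id: String,
    /// Name under which the model is served. Defaults to `model_id` if not provided.
    #[clap(long, env)]
    served_model_name: Option<String>,
}

fn main() {
    let args = Args::parse();

    // Same fallback as in the diff: use `model_id` when no explicit name is given.
    let served_model_name = args
        .served_model_name
        .clone()
        .unwrap_or_else(|| args.model_id.clone());

    // Metric call sites then carry the name as a `model_name` label
    // (metrics-crate syntax, exactly as used throughout the diff):
    metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(0.0);
    metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name)
        .increment(1);
}
```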
@@ -19,6 +19,10 @@ struct Args {
 #[clap(long, env)]
 model_id: String,

+/// Name under which the model is served. Defaults to `model_id` if not provided.
+#[clap(long, env)]
+served_model_name: Option<String>,
+
 /// Revision of the model.
 #[clap(default_value = "main", long, env)]
 revision: String,
@@ -152,6 +156,10 @@ struct Args {
 async fn main() -> Result<(), RouterError> {
 let args = Args::parse();

+let served_model_name = args.served_model_name
+.clone()
+.unwrap_or_else(|| args.model_id.clone());
+
 logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);

 let n_threads = match args.n_threads {
@@ -264,6 +272,7 @@ async fn main() -> Result<(), RouterError> {
 args.max_client_batch_size,
 args.usage_stats,
 args.payload_limit,
+served_model_name
 )
 .await?;
 Ok(())

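As a hedged aside on how clap maps the new field: with #[clap(long, env)] and no explicit name, the flag becomes --served-model-name and the environment variable SERVED_MODEL_NAME (clap's default derivation from the field name). The demo struct and values below are illustrative only:

```rust
// Assumes clap with the `derive` and `env` features, as in the router.
use clap::Parser;

#[derive(Parser, Debug)]
struct Demo {
    #[clap(long, env)]
    model_id: String,
    #[clap(long, env)]
    served_model_name: Option<String>,
}

fn main() {
    // An explicit flag takes precedence; otherwise clap reads SERVED_MODEL_NAME,
    // and the router finally falls back to `model_id` when the option stays `None`.
    let demo = Demo::parse_from([
        "demo",
        "--model-id",
        "bigscience/bloom-560m",
        "--served-model-name",
        "bloom-prod",
    ]);
    assert_eq!(demo.model_id, "bigscience/bloom-560m");
    assert_eq!(demo.served_model_name.as_deref(), Some("bloom-prod"));
}
```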
@@ -45,6 +45,8 @@ struct Args {
 revision: Option<String>,
 #[clap(long, env)]
 model_id: String,
+#[clap(long, env)]
+served_model_name: Option<String>,
 #[clap(default_value = "2", long, env)]
 validation_workers: usize,
 #[clap(long, env)]
@@ -227,6 +229,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
 tokenizer_config_path,
 revision,
 model_id,
+served_model_name,
 validation_workers,
 json_output,
 otlp_endpoint,
@@ -239,6 +242,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
 payload_limit,
 } = args;

+let served_model_name = args.served_model_name
+.clone()
+.unwrap_or_else(|| args.model_id.clone());
+
 // Launch Tokio runtime
 text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

@@ -318,6 +325,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
 max_client_batch_size,
 usage_stats,
 payload_limit,
+served_model_name,
 )
 .await?;
 Ok(())

@@ -34,6 +34,7 @@ impl BackendV2 {
 requires_padding: bool,
 window_size: Option<u32>,
 speculate: u32,
+served_model_name: String,
 ) -> Self {
 // Infer shared state
 let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
@@ -44,7 +45,7 @@ impl BackendV2 {
 _ => unreachable!(),
 };

-let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+let queue = Queue::new(requires_padding, block_size, window_size, speculate, served_model_name.clone());
 let batching_task_notifier = Arc::new(Notify::new());

 // Spawn batching background task that contains all the inference logic
@@ -57,6 +58,7 @@ impl BackendV2 {
 max_batch_size,
 queue.clone(),
 batching_task_notifier.clone(),
+served_model_name.clone(),
 ));

 Self {
@@ -128,6 +130,7 @@ pub(crate) async fn batching_task(
 max_batch_size: Option<usize>,
 queue: Queue,
 notifier: Arc<Notify>,
+served_model_name: String,
 ) {
 // Infinite loop
 loop {
@@ -146,7 +149,7 @@ pub(crate) async fn batching_task(
 )
 .await
 {
-let mut cached_batch = prefill(&mut client, batch, &mut entries)
+let mut cached_batch = prefill(&mut client, batch, &mut entries, served_model_name.clone())
 .instrument(span)
 .await;
 let mut waiting_tokens = 1;
@@ -158,8 +161,8 @@ pub(crate) async fn batching_task(
 let batch_size = batch.size;
 let batch_max_tokens = batch.max_tokens;
 let mut batches = vec![batch];
-metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
-metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
+metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64);
+metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64);

 let min_size = if waiting_tokens >= max_waiting_tokens {
 // If we didn't onboard any new requests since >= max_waiting_tokens, we try
@@ -180,10 +183,10 @@ pub(crate) async fn batching_task(
 {
 // Tracking metrics
 if min_size.is_some() {
-metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone())
 .increment(1);
 } else {
-metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone())
 .increment(1);
 }

@@ -199,7 +202,7 @@ pub(crate) async fn batching_task(
 });

 // Generate one token for this new batch to have the attention past in cache
-let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries, served_model_name.clone())
 .instrument(span)
 .await;
 // Reset waiting counter
@@ -225,13 +228,13 @@ pub(crate) async fn batching_task(
 entry.temp_span = Some(entry_batch_span);
 });

-cached_batch = decode(&mut client, batches, &mut entries)
+cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone())
 .instrument(next_batch_span)
 .await;
 waiting_tokens += 1;
 }
-metrics::gauge!("tgi_batch_current_size").set(0.0);
-metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
+metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0);
+metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0);
 }
 }
 }
@@ -241,36 +244,37 @@ async fn prefill(
 client: &mut ShardedClient,
 batch: Batch,
 entries: &mut IntMap<u64, Entry>,
+served_model_name: String,
 ) -> Option<CachedBatch> {
 let start_time = Instant::now();
 let batch_id = batch.id;
-metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
+metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);

 match client.prefill(batch).await {
 Ok((generations, next_batch, timings)) => {
 let start_filtering_time = Instant::now();
 // Send generated tokens and filter stopped entries
-filter_send_generations(generations, entries);
+filter_send_generations(generations, entries, served_model_name.clone());

 // Filter next batch and remove requests that were stopped
 let next_batch = filter_batch(client, next_batch, entries).await;

-metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
+metrics::histogram!("tgi_batch_forward_duration","method" => "prefill", "model_name" => served_model_name.clone())
 .record(timings.forward.as_secs_f64());
-metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone())
 .record(timings.decode.as_secs_f64());
-metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone())
 .record(start_filtering_time.elapsed().as_secs_f64());
-metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
+metrics::histogram!("tgi_batch_inference_duration","method" => "prefill", "model_name" => served_model_name.clone())
 .record(start_time.elapsed().as_secs_f64());
-metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
+metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
 next_batch
 }
 // If we have an error, we discard the whole batch
 Err(err) => {
 let _ = client.clear_cache(Some(batch_id)).await;
 send_errors(err, entries);
-metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
+metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
 None
 }
 }
@@ -281,33 +285,34 @@ async fn decode(
 client: &mut ShardedClient,
 batches: Vec<CachedBatch>,
 entries: &mut IntMap<u64, Entry>,
+served_model_name: String,
 ) -> Option<CachedBatch> {
 let start_time = Instant::now();
 let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
+metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);

 match client.decode(batches).await {
 Ok((generations, next_batch, timings)) => {
 let start_filtering_time = Instant::now();
 // Send generated tokens and filter stopped entries
-filter_send_generations(generations, entries);
+filter_send_generations(generations, entries, served_model_name.clone());

 // Filter next batch and remove requests that were stopped
 let next_batch = filter_batch(client, next_batch, entries).await;

 if let Some(concat_duration) = timings.concat {
-metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(concat_duration.as_secs_f64());
 }
-metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(timings.forward.as_secs_f64());
-metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(timings.decode.as_secs_f64());
-metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(start_filtering_time.elapsed().as_secs_f64());
-metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(start_time.elapsed().as_secs_f64());
-metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
+metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
 next_batch
 }
 // If we have an error, we discard the whole batch
@@ -316,7 +321,7 @@ async fn decode(
 let _ = client.clear_cache(Some(id)).await;
 }
 send_errors(err, entries);
-metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
+metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
 None
 }
 }
@@ -358,7 +363,7 @@ async fn filter_batch(
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
 /// and filter entries
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
 generations.into_iter().for_each(|generation| {
 let id = generation.request_id;
 // Get entry
@@ -372,9 +377,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 // Send generation responses back to the infer task
 // If the receive an error from the Flume channel, it means that the client dropped the
 // request and we need to stop generating hence why we unwrap_or(true)
-let stopped = send_responses(generation, entry).inspect_err(|_err| {
+let stopped = send_responses(generation, entry, served_model_name.clone()).inspect_err(|_err| {
 tracing::error!("Entry response channel error.");
-metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
 }).unwrap_or(true);
 if stopped {
 entries.remove(&id).expect("ID not found in entries. This is a bug.");
@@ -386,10 +391,11 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 fn send_responses(
 generation: Generation,
 entry: &Entry,
+served_model_name: String,
 ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
 // Return directly if the channel is disconnected
 if entry.response_tx.is_closed() {
-metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
 return Ok(true);
 }

@@ -415,7 +421,7 @@ fn send_responses(
 // Create last Token
 let tokens_ = generation.tokens.expect("Non empty tokens in generation");
 let n = tokens_.ids.len();
-metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
+metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64);
 let mut iterator = tokens_
 .ids
 .into_iter()

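A hedged illustration of the labelling pattern the hunks above repeat at every call site: the metrics macros accept additional "key" => value pairs, and each distinct label combination becomes its own series, so the shared /metrics output can be filtered by model_name. The helper below is illustrative only; upstream inlines these calls directly in prefill and decode:

```rust
// Illustrative helper, not an upstream function; assumes the `metrics` crate.
use std::time::Instant;

fn record_prefill(served_model_name: &str, start_time: Instant) {
    // One counter series per (method, model_name) combination.
    metrics::counter!(
        "tgi_batch_inference_count",
        "method" => "prefill",
        "model_name" => served_model_name.to_string()
    )
    .increment(1);

    // Same labels on the matching duration histogram.
    metrics::histogram!(
        "tgi_batch_inference_duration",
        "method" => "prefill",
        "model_name" => served_model_name.to_string()
    )
    .record(start_time.elapsed().as_secs_f64());
}

fn main() {
    let start = Instant::now();
    // Calls with different `model_name` values produce distinct series
    // for the same metric name.
    record_prefill("bloom-prod", start);
    record_prefill("llama-staging", start);
}
```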
@@ -39,6 +39,7 @@ pub async fn connect_backend(
 max_batch_total_tokens: Option<u32>,
 max_waiting_tokens: usize,
 max_batch_size: Option<usize>,
+served_model_name: String,
 ) -> Result<(BackendV2, BackendInfo), V2Error> {
 // Helper function
 let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
@@ -119,6 +120,7 @@ pub async fn connect_backend(
 shard_info.requires_padding,
 shard_info.window_size,
 shard_info.speculate,
+served_model_name,
 );

 tracing::info!("Using backend V3");

@@ -7,6 +7,9 @@ use thiserror::Error;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
+#[clap(long, env)]
+served_model_name: String,
+
 #[command(subcommand)]
 command: Option<Commands>,

@@ -83,8 +86,11 @@ enum Commands {
 async fn main() -> Result<(), RouterError> {
 // Get args
 let args = Args::parse();
+let _served_model_name = args.served_model_name.clone();
+
 // Pattern match configuration
 let Args {
+served_model_name,
 command,
 max_concurrent_requests,
 max_best_of,
@@ -170,6 +176,7 @@ async fn main() -> Result<(), RouterError> {
 max_batch_total_tokens,
 max_waiting_tokens,
 max_batch_size,
+served_model_name.clone(),
 )
 .await?;

@@ -198,6 +205,7 @@ async fn main() -> Result<(), RouterError> {
 max_client_batch_size,
 usage_stats,
 payload_limit,
+served_model_name.clone(),
 )
 .await?;
 Ok(())

@@ -43,6 +43,7 @@ impl Queue {
 block_size: u32,
 window_size: Option<u32>,
 speculate: u32,
+served_model_name: String,
 ) -> Self {
 // Create channel
 let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -54,6 +55,7 @@ impl Queue {
 window_size,
 speculate,
 queue_receiver,
+served_model_name,
 ));

 Self { queue_sender }
@@ -104,6 +106,7 @@ async fn queue_task(
 window_size: Option<u32>,
 speculate: u32,
 mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+served_model_name: String,
 ) {
 let mut state = State::new(requires_padding, block_size, window_size, speculate);

@@ -111,7 +114,7 @@ async fn queue_task(
 match cmd {
 QueueCommand::Append(entry, span) => {
 span.in_scope(|| state.append(*entry));
-metrics::gauge!("tgi_queue_size").increment(1.0);
+metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0);
 }
 QueueCommand::NextBatch {
 min_size,
@@ -122,9 +125,9 @@ async fn queue_task(
 span,
 } => span.in_scope(|| {
 let next_batch =
-state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
+state.next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone());
 response_sender.send(next_batch).unwrap();
-metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64);
 }),
 }
 }
@@ -191,6 +194,7 @@ impl State {
 max_size: Option<usize>,
 prefill_token_budget: u32,
 token_budget: u32,
+served_model_name: String,
 ) -> Option<NextBatch> {
 if self.entries.is_empty() {
 tracing::debug!("No queue");
@@ -232,7 +236,7 @@ impl State {
 // Filter entries where the response receiver was dropped (== entries where the request
 // was dropped by the client)
 if entry.response_tx.is_closed() {
-metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
 tracing::debug!("Dropping entry");
 continue;
 }
@@ -340,7 +344,7 @@ impl State {
 // Increment batch id
 self.next_batch_id += 1;

-metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
+metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64);

 Some((batch_entries, batch, next_batch_span))
 }
@@ -466,21 +470,23 @@ mod tests {

 #[test]
 fn test_next_batch_empty() {
+let served_model_name = "bigscience/blomm-560m".to_string();
 let mut state = State::new(false, 1, None, 0);

-assert!(state.next_batch(None, None, 1, 1).is_none());
-assert!(state.next_batch(Some(1), None, 1, 1).is_none());
+assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).is_none());
+assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).is_none());
 }

 #[test]
 fn test_next_batch_min_size() {
+let served_model_name = "bigscience/blomm-560m".to_string();
 let mut state = State::new(false, 1, None, 0);
 let (entry1, _guard1) = default_entry();
 let (entry2, _guard2) = default_entry();
 state.append(entry1);
 state.append(entry2);

-let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
+let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).unwrap();
 assert_eq!(entries.len(), 2);
 assert!(entries.contains_key(&0));
 assert!(entries.contains_key(&1));
@@ -496,7 +502,7 @@ mod tests {
 let (entry3, _guard3) = default_entry();
 state.append(entry3);

-assert!(state.next_batch(Some(2), None, 2, 2).is_none());
+assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).is_none());

 assert_eq!(state.next_id, 3);
 assert_eq!(state.entries.len(), 1);
@@ -506,13 +512,14 @@ mod tests {

 #[test]
 fn test_next_batch_max_size() {
+let served_model_name = "bigscience/blomm-560m".to_string();
 let mut state = State::new(false, 1, None, 0);
 let (entry1, _guard1) = default_entry();
 let (entry2, _guard2) = default_entry();
 state.append(entry1);
 state.append(entry2);

-let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).unwrap();
 assert_eq!(entries.len(), 1);
 assert!(entries.contains_key(&0));
 assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -526,13 +533,14 @@ mod tests {

 #[test]
 fn test_next_batch_token_budget() {
+let served_model_name = "bigscience/blomm-560m".to_string();
 let mut state = State::new(false, 1, None, 0);
 let (entry1, _guard1) = default_entry();
 let (entry2, _guard2) = default_entry();
 state.append(entry1);
 state.append(entry2);

-let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
+let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).unwrap();
 assert_eq!(entries.len(), 1);
 assert!(entries.contains_key(&0));
 assert_eq!(batch.id, 0);
@@ -545,7 +553,7 @@ mod tests {
 let (entry3, _guard3) = default_entry();
 state.append(entry3);

-let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
+let (entries, batch, _) = state.next_batch(None, None, 3, 3, served_model_name.clone()).unwrap();
 assert_eq!(entries.len(), 2);
 assert!(entries.contains_key(&1));
 assert!(entries.contains_key(&2));
@@ -559,14 +567,16 @@ mod tests {

 #[tokio::test]
 async fn test_queue_append() {
-let queue = Queue::new(false, 1, None, 0);
+let served_model_name = "bigscience/blomm-560m".to_string();
+let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
 let (entry, _guard) = default_entry();
 queue.append(entry);
 }

 #[tokio::test]
 async fn test_queue_next_batch_empty() {
-let queue = Queue::new(false, 1, None, 0);
+let served_model_name = "bigscience/blomm-560m".to_string();
+let queue = Queue::new(false, 1, None, 0, served_model_name.clone());

 assert!(queue.next_batch(None, None, 1, 1).await.is_none());
 assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -574,7 +584,8 @@ mod tests {

 #[tokio::test]
 async fn test_queue_next_batch_min_size() {
-let queue = Queue::new(false, 1, None, 0);
+let served_model_name = "bigscience/blomm-560m".to_string();
+let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
 let (entry1, _guard1) = default_entry();
 let (entry2, _guard2) = default_entry();
 queue.append(entry1);
@@ -607,7 +618,8 @@ mod tests {

 #[tokio::test]
 async fn test_queue_next_batch_max_size() {
-let queue = Queue::new(false, 1, None, 0);
+let served_model_name = "bigscience/blomm-560m".to_string();
+let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
 let (entry1, _guard1) = default_entry();
 let (entry2, _guard2) = default_entry();
 queue.append(entry1);
@@ -623,7 +635,9 @@ mod tests {

 #[tokio::test]
 async fn test_queue_next_batch_token_budget() {
-let queue = Queue::new(false, 1, None, 0);
+let served_model_name = "bigscience/blomm-560m".to_string();
+
+let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
 let (entry1, _guard1) = default_entry();
 let (entry2, _guard2) = default_entry();
 queue.append(entry1);
@@ -648,7 +662,9 @@ mod tests {

 #[tokio::test]
 async fn test_queue_next_batch_token_speculate() {
-let queue = Queue::new(false, 1, None, 2);
+let served_model_name = "bigscience/blomm-560m".to_string();
+
+let queue = Queue::new(false, 1, None, 2, served_model_name.clone());
 let (entry1, _guard1) = default_entry();
 let (entry2, _guard2) = default_entry();
 queue.append(entry1);
@@ -667,7 +683,9 @@ mod tests {

 #[tokio::test]
 async fn test_queue_next_batch_dropped_receiver() {
-let queue = Queue::new(false, 1, None, 0);
+let served_model_name = "bigscience/blomm-560m".to_string();
+
+let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
 let (entry, _) = default_entry();
 queue.append(entry);

@@ -34,6 +34,7 @@ impl BackendV3 {
 max_waiting_tokens: usize,
 max_batch_size: Option<usize>,
 shard_info: InfoResponse,
+served_model_name: String,
 ) -> Self {
 if shard_info.support_chunking {
 tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
@@ -49,6 +50,7 @@ impl BackendV3 {
 shard_info.speculate,
 max_batch_total_tokens,
 shard_info.support_chunking,
+served_model_name.clone(),
 );
 let batching_task_notifier = Arc::new(Notify::new());

@@ -63,6 +65,7 @@ impl BackendV3 {
 shard_info.support_chunking,
 queue.clone(),
 batching_task_notifier.clone(),
+served_model_name.clone(),
 ));

 Self {
@@ -136,6 +139,7 @@ pub(crate) async fn batching_task(
 support_chunking: bool,
 queue: Queue,
 notifier: Arc<Notify>,
+served_model_name: String,
 ) {
 // Infinite loop
 loop {
@@ -154,7 +158,7 @@ pub(crate) async fn batching_task(
 )
 .await
 {
-let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
+let mut cached_batch = prefill(&mut client, batch, None, &mut entries, served_model_name.clone())
 .instrument(span)
 .await;
 let mut waiting_tokens = 1;
@@ -167,8 +171,8 @@ pub(crate) async fn batching_task(
 let batch_max_tokens = batch.max_tokens;
 let current_tokens = batch.current_tokens;
 let mut batches = vec![batch];
-metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
-metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
+metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64);
+metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64);

 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

@@ -207,13 +211,13 @@ pub(crate) async fn batching_task(
 {
 // Tracking metrics
 if min_size.is_some() {
-metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone())
 .increment(1);
 } else {
 let counter = if support_chunking {
-metrics::counter!("tgi_batch_concat", "reason" => "chunking")
+metrics::counter!("tgi_batch_concat", "reason" => "chunking", "model_name" => served_model_name.clone())
 } else {
-metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone())
 };
 counter.increment(1);
 }
@@ -226,7 +230,7 @@ pub(crate) async fn batching_task(
 entries.extend(new_entries);
 // Generate one token for both the cached batch and the new batch
 let new_cached_batch =
-prefill(&mut client, new_batch, cached_batch, &mut entries)
+prefill(&mut client, new_batch, cached_batch, &mut entries, served_model_name.clone())
 .instrument(span)
 .await;
 if new_cached_batch.is_none() {
@@ -250,7 +254,7 @@ pub(crate) async fn batching_task(

 // Generate one token for this new batch to have the attention past in cache
 let new_cached_batch =
-prefill(&mut client, new_batch, None, &mut new_entries)
+prefill(&mut client, new_batch, None, &mut new_entries, served_model_name.clone())
 .instrument(span)
 .await;
 if new_cached_batch.is_some() {
@@ -282,13 +286,13 @@ pub(crate) async fn batching_task(
 entry.temp_span = Some(entry_batch_span);
 });

-cached_batch = decode(&mut client, batches, &mut entries)
+cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone())
 .instrument(next_batch_span)
 .await;
 waiting_tokens += 1;
 }
-metrics::gauge!("tgi_batch_current_size").set(0.0);
-metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
+metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0);
+metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0);
 }
 }
 }
@@ -299,40 +303,41 @@ async fn prefill(
 batch: Batch,
 cached_batch: Option<CachedBatch>,
 entries: &mut IntMap<u64, Entry>,
+served_model_name: String,
 ) -> Option<CachedBatch> {
 let start_time = Instant::now();
 let batch_id = batch.id;
-metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
+metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);

 match client.prefill(batch, cached_batch).await {
 Ok((generations, next_batch, timings)) => {
 let start_filtering_time = Instant::now();
 // Send generated tokens and filter stopped entries
-filter_send_generations(generations, entries);
+filter_send_generations(generations, entries, served_model_name.clone());

 // Filter next batch and remove requests that were stopped
 let next_batch = filter_batch(client, next_batch, entries).await;

 if let Some(concat_duration) = timings.concat {
-metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(concat_duration.as_secs_f64());
 }
-metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
+metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill", "model_name" => served_model_name.clone())
 .record(timings.forward.as_secs_f64());
-metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone())
 .record(timings.decode.as_secs_f64());
-metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone())
 .record(start_filtering_time.elapsed().as_secs_f64());
-metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
+metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill", "model_name" => served_model_name.clone())
 .record(start_time.elapsed().as_secs_f64());
-metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
+metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
 next_batch
 }
 // If we have an error, we discard the whole batch
 Err(err) => {
 let _ = client.clear_cache(Some(batch_id)).await;
-send_errors(err, entries);
-metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
+send_errors(err, entries, served_model_name.clone());
+metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
 None
 }
 }
@@ -343,33 +348,34 @@ async fn decode(
 client: &mut ShardedClient,
 batches: Vec<CachedBatch>,
 entries: &mut IntMap<u64, Entry>,
+served_model_name: String,
 ) -> Option<CachedBatch> {
 let start_time = Instant::now();
 let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
+metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);

 match client.decode(batches).await {
 Ok((generations, next_batch, timings)) => {
 let start_filtering_time = Instant::now();
 // Send generated tokens and filter stopped entries
-filter_send_generations(generations, entries);
+filter_send_generations(generations, entries, served_model_name.clone());

 // Filter next batch and remove requests that were stopped
 let next_batch = filter_batch(client, next_batch, entries).await;

 if let Some(concat_duration) = timings.concat {
-metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(concat_duration.as_secs_f64());
 }
-metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(timings.forward.as_secs_f64());
-metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(timings.decode.as_secs_f64());
-metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(start_filtering_time.elapsed().as_secs_f64());
-metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone())
 .record(start_time.elapsed().as_secs_f64());
-metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
+metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
 next_batch
 }
 // If we have an error, we discard the whole batch
@@ -377,8 +383,8 @@ async fn decode(
 for id in batch_ids {
 let _ = client.clear_cache(Some(id)).await;
 }
-send_errors(err, entries);
-metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
+send_errors(err, entries, served_model_name.clone());
+metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
 None
 }
 }
@@ -420,7 +426,7 @@ async fn filter_batch(
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
 /// and filter entries
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
 generations.into_iter().for_each(|generation| {
 let id = generation.request_id;
 // Get entry
@@ -434,9 +440,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 // Send generation responses back to the infer task
 // If the receive an error from the Flume channel, it means that the client dropped the
 // request and we need to stop generating hence why we unwrap_or(true)
-let stopped = send_responses(generation, entry).inspect_err(|_err| {
+let stopped = send_responses(generation, entry, served_model_name.clone()).inspect_err(|_err| {
 tracing::error!("Entry response channel error.");
-metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
 }).unwrap_or(true);
 if stopped {
 entries.remove(&id).expect("ID not found in entries. This is a bug.");
@@ -448,10 +454,11 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 fn send_responses(
 generation: Generation,
 entry: &Entry,
+served_model_name: String,
 ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
 // Return directly if the channel is disconnected
 if entry.response_tx.is_closed() {
-metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
 return Ok(true);
 }

@@ -477,7 +484,7 @@ fn send_responses(
 // Create last Token
 let tokens_ = generation.tokens.expect("Non empty tokens in generation");
 let n = tokens_.ids.len();
-metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
+metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64);
 let mut iterator = tokens_
 .ids
 .into_iter()
@@ -537,12 +544,12 @@ fn send_responses(

 /// Send errors to Infer for all `entries`
 #[instrument(skip_all)]
-fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
+fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
 entries.drain().for_each(|(_, entry)| {
 // Create and enter a span to link this function back to the entry
 let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
 let err = InferError::GenerationError(error.to_string());
-metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
+metrics::counter!("tgi_request_failure", "err" => "generation", "model_name" => served_model_name.clone()).increment(1);
 tracing::error!("{err}");

 // unwrap_or is valid here as we don't care if the receiver is gone.

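To see the effect end to end, a hedged sketch (assuming the metrics and metrics-exporter-prometheus crates) that installs an in-process recorder and renders the exposition text a Prometheus scrape would see; the values and model name are illustrative:

```rust
// Sketch only; assumes the `metrics` and `metrics-exporter-prometheus` crates.
use metrics_exporter_prometheus::PrometheusBuilder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Install an in-process recorder and keep a handle to render its state.
    let handle = PrometheusBuilder::new().install_recorder()?;

    let served_model_name = "bloom-prod".to_string();
    metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(3.0);
    metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name)
        .increment(1);

    // Prints exposition text along the lines of:
    //   tgi_queue_size{model_name="bloom-prod"} 3
    //   tgi_batch_inference_count{method="prefill",model_name="bloom-prod"} 1
    println!("{}", handle.render());
    Ok(())
}
```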
@@ -54,6 +54,7 @@ pub async fn connect_backend(
 max_batch_total_tokens: Option<u32>,
 max_waiting_tokens: usize,
 max_batch_size: Option<usize>,
+served_model_name: String,
 ) -> Result<(BackendV3, BackendInfo), V3Error> {
 // Helper function
 let check_max_batch_total_tokens = |(
@@ -161,6 +162,7 @@ pub async fn connect_backend(
 max_waiting_tokens,
 max_batch_size,
 shard_info,
+served_model_name,
 );

 tracing::info!("Using backend V3");

@@ -7,6 +7,9 @@ use thiserror::Error;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
+#[clap(long, env)]
+served_model_name: String,
+
 #[command(subcommand)]
 command: Option<Commands>,

@@ -74,6 +77,7 @@ struct Args {
 payload_limit: usize,
 }

+
 #[derive(Debug, Subcommand)]
 enum Commands {
 PrintSchema,
@@ -83,8 +87,11 @@ enum Commands {
 async fn main() -> Result<(), RouterError> {
 // Get args
 let args = Args::parse();
+let _served_model_name = args.served_model_name.clone();
+
 // Pattern match configuration
 let Args {
+served_model_name,
 command,
 max_concurrent_requests,
 max_best_of,
@@ -151,6 +158,7 @@ async fn main() -> Result<(), RouterError> {
 max_batch_total_tokens,
 max_waiting_tokens,
 max_batch_size,
+served_model_name.clone(),
 )
 .await?;

@@ -214,6 +222,7 @@ async fn main() -> Result<(), RouterError> {
 max_client_batch_size,
 usage_stats,
 payload_limit,
+served_model_name.clone(),
 )
 .await?;
 Ok(())

|
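A self-contained sketch of how the new argument is surfaced, assuming clap's derive API with its `env` feature: `#[clap(long, env)]` on a `served_model_name` field yields both a `--served-model-name` flag and a `SERVED_MODEL_NAME` environment variable. The struct below is illustrative only.

use clap::Parser;

#[derive(Parser, Debug)]
struct SketchArgs {
    /// Filled from `--served-model-name` or the `SERVED_MODEL_NAME` env var.
    #[clap(long, env)]
    served_model_name: String,
}

fn main() {
    let args = SketchArgs::parse();
    println!("metrics will carry model_name={}", args.served_model_name);
}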
@@ -51,6 +51,7 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        served_model_name: String,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -65,6 +66,7 @@ impl Queue {
             max_batch_total_tokens,
             support_chunking,
             queue_receiver,
+            served_model_name,
         ));
 
         Self { queue_sender }
@@ -124,6 +126,7 @@ async fn queue_task(
     max_batch_total_tokens: u32,
     support_chunking: bool,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    served_model_name: String,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -139,7 +142,7 @@ async fn queue_task(
         match cmd {
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
-                metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0);
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -150,11 +153,11 @@ async fn queue_task(
                 span,
             } => {
                 let next_batch = state
-                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone())
                     .instrument(span)
                     .await;
                 response_sender.send(next_batch).unwrap();
-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64);
             }
         }
     }
@@ -235,6 +238,7 @@ impl State {
         max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
+        served_model_name: String,
     ) -> Option<NextBatch> {
         if self.entries.is_empty() {
             tracing::debug!("No queue");
@@ -274,7 +278,7 @@ impl State {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
-                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+                metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
                 tracing::debug!("Dropping entry");
                 continue;
             }
@@ -478,7 +482,7 @@ impl State {
         // Increment batch id
         self.next_batch_id += 1;
 
-        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
+        metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64);
 
         Some((batch_entries, batch, next_batch_span))
     }
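The queue gauge follows the same pattern; a hedged sketch of what the labelled series amounts to once scraped (metric name taken from the hunks above, label value and reading invented):

use metrics::gauge;

fn track_queue_depth(served_model_name: &str, depth: usize) {
    // One gauge series per distinct `model_name` value, e.g. after a scrape:
    //   tgi_queue_size{model_name="my-served-model"} 3
    gauge!("tgi_queue_size", "model_name" => served_model_name.to_string()).set(depth as f64);
}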
@@ -606,21 +610,24 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_empty() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, false, None, 0, 16, false);
 
-        assert!(state.next_batch(None, None, 1, 1).await.is_none());
-        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
+        assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).await.is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).await.is_none());
     }
 
     #[tokio::test]
     async fn test_next_batch_min_size() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
+
         let mut state = State::new(false, 1, false, None, 0, 16, false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);
 
-        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -636,7 +643,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
 
-        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).await.is_none());
 
         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -646,13 +653,14 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_max_size() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, false, None, 0, 16, false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);
 
-        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -666,13 +674,14 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_token_budget() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, false, None, 0, 16, false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);
 
-        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -685,7 +694,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
 
-        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -699,14 +708,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }
 
     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
 
         assert!(queue.next_batch(None, None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -714,7 +725,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -747,7 +759,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -763,7 +776,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -788,7 +802,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(true, 1, false, None, 2, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(true, 1, false, None, 2, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -807,7 +822,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry, _) = default_entry();
         queue.append(entry);
 
@@ -560,6 +560,10 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,
 
+    /// Name under which the model is served. Defaults to `model_id` if not provided.
+    #[clap(long, env)]
+    served_model_name: Option<String>,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -1802,7 +1806,7 @@ fn spawn_webserver(
         "--master-shard-uds-path".to_string(),
         format!("{}-0", args.shard_uds_path),
         "--tokenizer-name".to_string(),
-        args.model_id,
+        args.model_id.clone(),
         "--payload-limit".to_string(),
         args.payload_limit.to_string(),
     ];
@@ -1973,6 +1977,12 @@ fn main() -> Result<(), LauncherError> {
     // Pattern match configuration
     let args: Args = Args::parse();
 
+    let served_model_name = args.served_model_name
+        .clone()
+        .unwrap_or_else(|| args.model_id.clone());
+
+    env::set_var("SERVED_MODEL_NAME", &served_model_name);
+
     // Filter events with LOG_LEVEL
     let varname = "LOG_LEVEL";
     let env_filter = if let Ok(log_level) = std::env::var(varname) {
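A hedged sketch of why setting the variable in the launcher is enough: spawned child processes inherit the launcher's environment, so the router's `#[clap(long, env)]` field resolves without an explicit flag. The binary name below is illustrative, and the explicit `.env(...)` call only makes the inheritance visible.

use std::process::{Child, Command};

fn spawn_router_sketch(served_model_name: &str) -> std::io::Result<Child> {
    // Equivalent in effect to the launcher's process-wide
    // `env::set_var("SERVED_MODEL_NAME", ...)` followed by a plain spawn.
    Command::new("text-generation-router")
        .env("SERVED_MODEL_NAME", served_model_name)
        .spawn()
}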
@@ -100,6 +100,7 @@ impl Infer {
     pub(crate) async fn generate_stream<'a>(
         &'a self,
         request: GenerateRequest,
+        served_model_name: String,
     ) -> Result<
         (
             OwnedSemaphorePermit,
@@ -121,7 +122,7 @@ impl Infer {
 
         // Validate request
        let mut local_request = request.clone();
-        let valid_request = self.validation.validate(request).await.map_err(|err| {
+        let valid_request = self.validation.validate(request, served_model_name.clone()).await.map_err(|err| {
            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
            tracing::error!("{err}");
            err
@@ -165,7 +166,7 @@ impl Infer {
                local_request.inputs.push_str(&generated_text.text);
                all_generated_text = all_generated_text.or(Some(generated_text));
 
-                let valid_request = match self.validation.validate(local_request.clone()).await {
+                let valid_request = match self.validation.validate(local_request.clone(), served_model_name.clone()).await {
                    Ok(valid_request) => valid_request,
                    Err(err) => {
                        tracing::debug!("Failed to continue request: {err}");
@@ -245,11 +246,12 @@ impl Infer {
     pub(crate) async fn generate(
         &self,
         request: GenerateRequest,
+        served_model_name: String,
     ) -> Result<InferResponse, InferError> {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
 
         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, _input_length, stream) = self.generate_stream(request).await?;
+        let (_permit, _input_length, stream) = self.generate_stream(request, served_model_name).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -322,13 +324,14 @@ impl Infer {
         &self,
         request: GenerateRequest,
         best_of: usize,
+        served_model_name: String,
     ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
         // validate best_of parameter separately
         let best_of = self.validation.validate_best_of(best_of)?;
 
         // create multiple generate requests
         let mut infer_responses: Vec<InferResponse> =
-            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
+            try_join_all((0..best_of).map(|_| self.generate(request.clone(), served_model_name.clone()))).await?;
 
         // get the sequence with the highest log probability per token
         let mut max_index = 0;
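A hedged sketch of the updated `Infer` call sites; `infer`, `req`, `best_of`, and `served_model_name` are assumed to already be in scope.

// Single response
let response = infer.generate(req.clone(), served_model_name.clone()).await?;

// Streaming: the permit is held for as long as the stream lives
let (_permit, _input_length, stream) =
    infer.generate_stream(req.clone(), served_model_name.clone()).await?;

// best_of sampling
let (best, other_responses) = infer
    .generate_best_of(req, best_of, served_model_name.clone())
    .await?;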
@@ -68,15 +68,16 @@ pub(crate) async fn sagemaker_compatibility(
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
     info: Extension<Info>,
+    served_model_name: Extension<String>,
     Json(req): Json<SagemakerRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     match req {
         SagemakerRequest::Generate(req) => {
-            compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
+            compat_generate(default_return_full_text, infer, compute_type, served_model_name, Json(req)).await
         }
-        SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
+        SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, served_model_name, Json(req)).await,
         SagemakerRequest::Completion(req) => {
-            completions(infer, compute_type, info, Json(req)).await
+            completions(infer, compute_type, info, served_model_name, Json(req)).await
         }
     }
 }
@@ -129,6 +129,7 @@ pub(crate) async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
+    served_model_name: Extension<String>,
     Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // default return_full_text given the pipeline_tag
@@ -138,11 +139,11 @@ pub(crate) async fn compat_generate(
 
     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, compute_type, Json(req.into()))
+        Ok(generate_stream(infer, served_model_name.clone(), compute_type, Json(req.into()))
             .await
             .into_response())
     } else {
-        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, served_model_name.clone(), compute_type, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation])).into_response())
     }
@@ -196,9 +197,10 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
 )]
 async fn get_chat_tokenize(
     Extension(infer): Extension<Infer>,
+    Extension(served_model_name): Extension<String>,
     Json(chat): Json<ChatRequest>,
 ) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
-    metrics::counter!("tgi_request_count").increment(1);
+    metrics::counter!("tgi_request_count", "model_name" => served_model_name).increment(1);
 
     let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
     let input = generate_request.inputs.clone();
@@ -270,23 +272,25 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
+    served_model_name: Extension<String>,
     Extension(ComputeType(compute_type)): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     let (headers, _, response) =
-        generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
+        generate_internal(infer, served_model_name, ComputeType(compute_type), Json(req), span).await?;
     Ok((headers, response))
 }
 
 pub(crate) async fn generate_internal(
     infer: Extension<Infer>,
+    served_model_name: Extension<String>,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     span: tracing::Span,
 ) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
-    metrics::counter!("tgi_request_count").increment(1);
+    metrics::counter!("tgi_request_count", "model_name" => served_model_name.0.clone()).increment(1);
 
     // Do not long ultra long inputs, like image payloads.
     tracing::debug!(
@@ -305,10 +309,10 @@ pub(crate) async fn generate_internal(
     // Inference
     let (response, best_of_responses) = match req.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
+            let (response, best_of_responses) = infer.generate_best_of(req, best_of, served_model_name.0.clone()).await?;
             (response, Some(best_of_responses))
         }
-        _ => (infer.generate(req).await?, None),
+        _ => (infer.generate(req, served_model_name.0.clone()).await?, None),
     };
 
     // Token details
@@ -405,14 +409,14 @@ pub(crate) async fn generate_internal(
     );
 
     // Metrics
-    metrics::counter!("tgi_request_success").increment(1);
-    metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
-    metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
-    metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
-    metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
-    metrics::histogram!("tgi_request_mean_time_per_token_duration")
+    metrics::counter!("tgi_request_success", "model_name" => served_model_name.0.clone()).increment(1);
+    metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.0.clone()).record(total_time.as_secs_f64());
+    metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.0.clone()).record(validation_time.as_secs_f64());
+    metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.0.clone()).record(queue_time.as_secs_f64());
+    metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.0.clone()).record(inference_time.as_secs_f64());
+    metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.0.clone())
         .record(time_per_token.as_secs_f64());
-    metrics::histogram!("tgi_request_generated_tokens")
+    metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.0.clone())
         .record(response.generated_text.generated_tokens as f64);
 
     // Send response
@@ -468,6 +472,7 @@ seed,
 )]
 async fn generate_stream(
     Extension(infer): Extension<Infer>,
+    Extension(served_model_name): Extension<String>,
     Extension(compute_type): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> (
@@ -476,7 +481,7 @@ async fn generate_stream(
 ) {
     let span = tracing::Span::current();
     let (headers, response_stream) =
-        generate_stream_internal(infer, compute_type, Json(req), span).await;
+        generate_stream_internal(infer, served_model_name, compute_type, Json(req), span).await;
 
     let response_stream = async_stream::stream! {
         let mut response_stream = Box::pin(response_stream);
@@ -495,6 +500,7 @@ async fn generate_stream(
 
 async fn generate_stream_internal(
     infer: Infer,
+    served_model_name: String,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     span: tracing::Span,
@@ -503,7 +509,7 @@ async fn generate_stream_internal(
     impl Stream<Item = Result<StreamResponse, InferError>>,
 ) {
     let start_time = Instant::now();
-    metrics::counter!("tgi_request_count").increment(1);
+    metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
 
     tracing::debug!("Input: {}", req.inputs);
 
@@ -540,7 +546,7 @@ async fn generate_stream_internal(
             tracing::error!("{err}");
             yield Err(err);
         } else {
-            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
+            match infer.generate_stream(req, served_model_name.clone()).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, input_length, response_stream)) => {
                     let mut index = 0;
@@ -605,13 +611,13 @@ async fn generate_stream_internal(
                             span.record("seed", format!("{:?}", generated_text.seed));
 
                             // Metrics
-                            metrics::counter!("tgi_request_success").increment(1);
-                            metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
-                            metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
-                            metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
-                            metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
-                            metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
-                            metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);
+                            metrics::counter!("tgi_request_success", "model_name" => served_model_name.clone()).increment(1);
+                            metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.clone()).record(total_time.as_secs_f64());
+                            metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.clone()).record(validation_time.as_secs_f64());
+                            metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.clone()).record(queue_time.as_secs_f64());
+                            metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.clone()).record(inference_time.as_secs_f64());
+                            metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.clone()).record(time_per_token.as_secs_f64());
+                            metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.clone()).record(generated_text.generated_tokens as f64);
 
                             // StreamResponse
                             end_reached = true;
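Purely illustrative, assuming the default Prometheus exporter: with the label attached, the per-request metrics above are exposed as one series per serving name. A sketch of the recording side, with the approximate scrape output in the trailing comments (values invented):

use metrics::{counter, histogram};

fn record_success(served_model_name: &str, total_time_s: f64) {
    counter!("tgi_request_success", "model_name" => served_model_name.to_string()).increment(1);
    histogram!("tgi_request_duration", "model_name" => served_model_name.to_string()).record(total_time_s);
    // tgi_request_success{model_name="my-served-model"} 1
    // tgi_request_duration_sum{model_name="my-served-model"} 0.42
    // tgi_request_duration_count{model_name="my-served-model"} 1
}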
@@ -704,10 +710,11 @@ pub(crate) async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
+    Extension(served_model_name): Extension<String>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
-    metrics::counter!("tgi_request_count").increment(1);
+    metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
 
     let CompletionRequest {
         model,
@@ -798,6 +805,7 @@ pub(crate) async fn completions(
         let infer_clone = infer.clone();
         let compute_type_clone = compute_type.clone();
         let span_clone = span.clone();
+        let served_model_name_clone = served_model_name.clone();
 
         // Create a future for each generate_stream_internal call.
         let generate_future = async move {
@@ -807,6 +815,7 @@ pub(crate) async fn completions(
             tokio::spawn(async move {
                 let (headers, response_stream) = generate_stream_internal(
                     infer_clone.clone(),
+                    served_model_name_clone.clone(),
                     compute_type_clone.clone(),
                     Json(generate_request),
                     span_clone.clone(),
@@ -975,11 +984,13 @@ pub(crate) async fn completions(
         let responses = FuturesUnordered::new();
         for (index, generate_request) in generate_requests.into_iter().enumerate() {
             let infer_clone = infer.clone();
+            let served_model_name_clone = served_model_name.clone();
             let compute_type_clone = compute_type.clone();
             let span_clone = span.clone();
             let response_future = async move {
                 let result = generate_internal(
                     Extension(infer_clone),
+                    Extension(served_model_name_clone),
                     compute_type_clone,
                     Json(generate_request),
                     span_clone,
@@ -1230,10 +1241,11 @@ pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
+    Extension(served_model_name): Extension<String>,
     Json(chat): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
-    metrics::counter!("tgi_request_count").increment(1);
+    metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
     let ChatRequest {
         model,
         stream,
@@ -1255,7 +1267,7 @@ pub(crate) async fn chat_completions(
     // switch on stream
     if stream {
         let (headers, response_stream) =
-            generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
+            generate_stream_internal(infer, served_model_name, compute_type, Json(generate_request), span).await;
 
         // regex to match any function name
         let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
@@ -1389,7 +1401,7 @@ pub(crate) async fn chat_completions(
         Ok((headers, sse).into_response())
     } else {
         let (headers, input_length, Json(generation)) =
-            generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
+            generate_internal(Extension(infer), Extension(served_model_name), compute_type, Json(generate_request), span).await?;
 
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -1688,6 +1700,7 @@ pub async fn run(
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
     payload_limit: usize,
+    served_model_name: String,
 ) -> Result<(), WebServerError> {
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
@@ -1963,6 +1976,7 @@ pub async fn run(
         compat_return_full_text,
         allow_origin,
         payload_limit,
+        served_model_name,
     )
     .await;
 
@@ -2024,6 +2038,7 @@ async fn start(
     compat_return_full_text: bool,
     allow_origin: Option<AllowOrigin>,
     payload_limit: usize,
+    served_model_name: String,
 ) -> Result<(), WebServerError> {
     // Determine the server port based on the feature and environment variable.
     let port = if cfg!(feature = "google") {
@@ -2076,22 +2091,22 @@ async fn start(
         duration_buckets.push(value);
     }
     // Input Length buckets
-    let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
+    let input_length_matcher = Matcher::Full(format!("tgi_request_input_length{{model_name=\"{}\"}}", served_model_name));
     let input_length_buckets: Vec<f64> = (0..100)
         .map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64)
         .collect();
     // Generated tokens buckets
-    let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
+    let generated_tokens_matcher = Matcher::Full(format!("tgi_request_generated_tokens{{model_name=\"{}\"}}", served_model_name));
     let generated_tokens_buckets: Vec<f64> = (0..100)
         .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
         .collect();
     // Input Length buckets
-    let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens"));
+    let max_new_tokens_matcher = Matcher::Full(format!("tgi_request_max_new_tokens{{model_name=\"{}\"}}", served_model_name));
     let max_new_tokens_buckets: Vec<f64> = (0..100)
         .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
         .collect();
     // Batch size buckets
-    let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
+    let batch_size_matcher = Matcher::Full(format!("tgi_batch_next_size{{model_name=\"{}\"}}", served_model_name));
     let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
     // Speculated tokens buckets
     // let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
@@ -2334,7 +2349,8 @@ async fn start(
         .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
         .route("/invocations", post(sagemaker_compatibility))
-        .route("/tokenize", post(tokenize));
+        .route("/tokenize", post(tokenize))
+        .layer(Extension(served_model_name));
 
     if let Some(api_key) = api_key {
         let mut prefix = "Bearer ".to_string();
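A minimal, self-contained sketch (route and handler names invented) of the axum pattern used in the routing hunk above: `.layer(Extension(value))` stores the value in the request extensions, and any handler that declares an `Extension<String>` extractor receives a clone of it.

use axum::{extract::Extension, routing::post, Router};

async fn sketch_handler(Extension(served_model_name): Extension<String>) -> String {
    format!("labelling metrics with model_name={served_model_name}")
}

fn sketch_app(served_model_name: String) -> Router {
    Router::new()
        .route("/generate", post(sketch_handler))
        .layer(Extension(served_model_name))
}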
@@ -133,6 +133,7 @@ impl Validation {
         add_special_tokens: bool,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
+        served_model_name: String,
     ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
         // If we have a fast tokenizer
         let (encoding, inputs) = self
@@ -186,7 +187,7 @@ impl Validation {
         let ids = encoding.get_ids();
         let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
 
-        metrics::histogram!("tgi_request_input_length").record(input_length as f64);
+        metrics::histogram!("tgi_request_input_length", "model_name" => served_model_name.clone()).record(input_length as f64);
         Ok((
             inputs,
             Some(input_ids),
@@ -201,6 +202,7 @@ impl Validation {
     pub(crate) async fn validate(
         &self,
         request: GenerateRequest,
+        served_model_name: String,
     ) -> Result<ValidGenerateRequest, ValidationError> {
         let GenerateParameters {
             best_of,
@@ -332,6 +334,7 @@ impl Validation {
                 request.add_special_tokens,
                 truncate,
                 max_new_tokens,
+                served_model_name.clone(),
             )
             .await?;
 
@@ -405,7 +408,7 @@ impl Validation {
             ignore_eos_token: false,
         };
 
-        metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);
+        metrics::histogram!("tgi_request_max_new_tokens", "model_name" => served_model_name.clone()).record(max_new_tokens as f64);
 
         Ok(ValidGenerateRequest {
             inputs,
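A hedged sketch of the updated call shape for `Validation::validate`; `validation`, `request`, and `served_model_name` are assumed to be in scope, and the name is only used to label the validation-side histograms.

// tgi_request_input_length and tgi_request_max_new_tokens pick up the label here.
let valid_request = validation
    .validate(request, served_model_name.clone())
    .await?;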
@@ -953,10 +956,10 @@ mod tests {
             max_total_tokens,
             disable_grammar_support,
         );
-
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name)
             .await
         {
             Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -989,9 +992,10 @@ mod tests {
             disable_grammar_support,
         );
 
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name)
             .await
         {
             Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -1022,6 +1026,7 @@ mod tests {
             max_total_tokens,
             disable_grammar_support,
         );
+        let served_model_name = "bigscience/blomm-560m".to_string();
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
@@ -1031,7 +1036,7 @@ mod tests {
                     do_sample: false,
                     ..default_parameters()
                 },
-            })
+            }, served_model_name)
             .await
         {
             Err(ValidationError::BestOfSampling) => (),
@@ -1062,6 +1067,7 @@ mod tests {
             max_total_tokens,
             disable_grammar_support,
         );
+        let served_model_name = "bigscience/blomm-560m".to_string();
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
@@ -1071,7 +1077,7 @@ mod tests {
                     max_new_tokens: Some(5),
                     ..default_parameters()
                 },
-            })
+            }, served_model_name.clone())
             .await
         {
             Err(ValidationError::TopP) => (),
@@ -1087,7 +1093,7 @@ mod tests {
                     max_new_tokens: Some(5),
                     ..default_parameters()
                 },
-            })
+            }, served_model_name.clone())
             .await
         {
             Ok(_) => (),
@@ -1103,7 +1109,7 @@ mod tests {
                     max_new_tokens: Some(5),
                     ..default_parameters()
                 },
-            })
+            }, served_model_name.clone())
             .await
             .unwrap();
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
@@ -1133,6 +1139,7 @@ mod tests {
             max_total_tokens,
             disable_grammar_support,
         );
+        let served_model_name = "bigscience/blomm-560m".to_string();
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
@@ -1142,7 +1149,7 @@ mod tests {
                     max_new_tokens: Some(5),
                     ..default_parameters()
                 },
-            })
+            }, served_model_name.clone())
            .await
         {
             Err(ValidationError::TopNTokens(4, 5)) => (),
@@ -1158,7 +1165,7 @@ mod tests {
                     max_new_tokens: Some(5),
                     ..default_parameters()
                 },
-            })
+            }, served_model_name.clone())
             .await
             .unwrap();
 
@@ -1171,7 +1178,7 @@ mod tests {
                     max_new_tokens: Some(5),
                     ..default_parameters()
                 },
-            })
+            }, served_model_name.clone())
             .await
             .unwrap();
 
@@ -1184,7 +1191,7 @@ mod tests {
                     max_new_tokens: Some(5),
                     ..default_parameters()
                 },
-            })
+            }, served_model_name.clone())
             .await
             .unwrap();
 
@@ -69,11 +69,12 @@ example = json ! ({"error": "Incomplete generation"})),
 )]
 pub(crate) async fn vertex_compatibility(
     Extension(infer): Extension<Infer>,
+    Extension(served_model_name): Extension<String>,
     Extension(compute_type): Extension<ComputeType>,
     Json(req): Json<VertexRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
-    metrics::counter!("tgi_request_count").increment(1);
+    metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
 
     // check that theres at least one instance
     if req.instances.is_empty() {
@@ -111,12 +112,14 @@ pub(crate) async fn vertex_compatibility(
         };
 
         let infer_clone = infer.clone();
+        let served_model_name_clone = served_model_name.clone();
         let compute_type_clone = compute_type.clone();
         let span_clone = span.clone();
 
         futures.push(async move {
             generate_internal(
                 Extension(infer_clone),
+                Extension(served_model_name_clone),
                 compute_type_clone,
                 Json(generate_request),
                 span_clone,