diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index 5a07acdcd..0a5a9de9b 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -19,6 +19,10 @@ struct Args { #[clap(long, env)] model_id: String, + /// Name under which the model is served. Defaults to `model_id` if not provided. + #[clap(long, env)] + served_model_name: Option<String>, + /// Revision of the model. #[clap(default_value = "main", long, env)] revision: String, @@ -152,6 +156,10 @@ struct Args { async fn main() -> Result<(), RouterError> { let args = Args::parse(); + let served_model_name = args.served_model_name + .clone() + .unwrap_or_else(|| args.model_id.clone()); + logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output); let n_threads = match args.n_threads { @@ -264,6 +272,7 @@ async fn main() -> Result<(), RouterError> { args.max_client_batch_size, args.usage_stats, args.payload_limit, + served_model_name ) .await?; Ok(()) diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index cef225be1..e3be62d3f 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -45,6 +45,8 @@ struct Args { revision: Option<String>, #[clap(long, env)] model_id: String, + #[clap(long, env)] + served_model_name: Option<String>, #[clap(default_value = "2", long, env)] validation_workers: usize, #[clap(long, env)] @@ -227,6 +229,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { tokenizer_config_path, revision, model_id, + served_model_name, validation_workers, json_output, otlp_endpoint, @@ -239,6 +242,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { payload_limit, } = args; + let served_model_name = served_model_name + .clone() + .unwrap_or_else(|| model_id.clone()); + // Launch Tokio runtime text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); @@ -318,6 +325,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { max_client_batch_size, usage_stats, payload_limit, + served_model_name, ) .await?; Ok(()) diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs index adca3d5d2..85c70a233 100644 --- a/backends/v2/src/backend.rs +++ b/backends/v2/src/backend.rs @@ -34,6 +34,7 @@ impl BackendV2 { requires_padding: bool, window_size: Option<u32>, speculate: u32, + served_model_name: String, ) -> Self { // Infer shared state let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string()); @@ -44,7 +45,7 @@ impl BackendV2 { _ => unreachable!(), }; - let queue = Queue::new(requires_padding, block_size, window_size, speculate); + let queue = Queue::new(requires_padding, block_size, window_size, speculate, served_model_name.clone()); let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic @@ -57,6 +58,7 @@ impl BackendV2 { max_batch_size, queue.clone(), batching_task_notifier.clone(), + served_model_name.clone(), )); Self { @@ -128,6 +130,7 @@ pub(crate) async fn batching_task( max_batch_size: Option<usize>, queue: Queue, notifier: Arc<Notify>, + served_model_name: String, ) { // Infinite loop loop { @@ -146,7 +149,7 @@ pub(crate) async fn batching_task( ) .await { - let mut cached_batch = prefill(&mut client, batch, &mut entries) + let mut cached_batch = prefill(&mut client, batch, &mut entries, served_model_name.clone()) .instrument(span) .await; let mut waiting_tokens = 1; @@ -158,8 +161,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut 
batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -180,10 +183,10 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone()) .increment(1); } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone()) .increment(1); } @@ -199,7 +202,7 @@ pub(crate) async fn batching_task( }); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries, served_model_name.clone()) .instrument(span) .await; // Reset waiting counter @@ -225,13 +228,13 @@ pub(crate) async fn batching_task( entry.temp_span = Some(entry_batch_span); }); - cached_batch = decode(&mut client, batches, &mut entries) + cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone()) .instrument(next_batch_span) .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size").set(0.0); - metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0); } } } @@ -241,36 +244,37 @@ async fn prefill( client: &mut ShardedClient, batch: Batch, entries: &mut IntMap, + served_model_name: String, ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + filter_send_generations(generations, entries, served_model_name.clone()); // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill", "model_name" => served_model_name.clone()) .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone()) .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone()) 
.record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") + metrics::histogram!("tgi_batch_inference_duration","method" => "prefill", "model_name" => served_model_name.clone()) .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1); next_batch } // If we have an error, we discard the whole batch Err(err) => { let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1); None } } @@ -281,33 +285,34 @@ async fn decode( client: &mut ShardedClient, batches: Vec, entries: &mut IntMap, + served_model_name: String, ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + filter_send_generations(generations, entries, served_model_name.clone()); // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1); next_batch } // If we have an error, we discard the whole batch @@ -316,7 +321,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode", 
"model_name" => served_model_name.clone()).increment(1); None } } @@ -358,7 +363,7 @@ async fn filter_batch( /// Send one or multiple `InferStreamResponse` to Infer for all `entries` /// and filter entries #[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { +fn filter_send_generations(generations: Vec, entries: &mut IntMap, served_model_name: String) { generations.into_iter().for_each(|generation| { let id = generation.request_id; // Get entry @@ -372,9 +377,9 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped").increment(1); + metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1); }).unwrap_or(true); if stopped { entries.remove(&id).expect("ID not found in entries. This is a bug."); @@ -386,10 +391,11 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1); return Ok(true); } @@ -415,7 +421,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() diff --git a/backends/v2/src/lib.rs b/backends/v2/src/lib.rs index 85c36931c..f675fc35d 100644 --- a/backends/v2/src/lib.rs +++ b/backends/v2/src/lib.rs @@ -39,6 +39,7 @@ pub async fn connect_backend( max_batch_total_tokens: Option, max_waiting_tokens: usize, max_batch_size: Option, + served_model_name: String, ) -> Result<(BackendV2, BackendInfo), V2Error> { // Helper function let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { @@ -108,7 +109,7 @@ pub async fn connect_backend( model_dtype: shard_info.dtype.clone(), speculate: shard_info.speculate as usize, }; - + let backend = BackendV2::new( sharded_client, waiting_served_ratio, @@ -119,6 +120,7 @@ pub async fn connect_backend( shard_info.requires_padding, shard_info.window_size, shard_info.speculate, + served_model_name, ); tracing::info!("Using backend V3"); diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index f537690e4..3c0d6a72c 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -7,6 +7,9 @@ use thiserror::Error; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + #[clap(long, env)] + served_model_name: String, + #[command(subcommand)] command: Option, @@ -83,8 +86,11 @@ enum Commands { async fn main() -> Result<(), RouterError> { // Get args let args = Args::parse(); + let _served_model_name = args.served_model_name.clone(); + // Pattern match configuration let Args { + served_model_name, command, max_concurrent_requests, max_best_of, @@ -170,8 +176,9 @@ async fn main() -> Result<(), RouterError> { max_batch_total_tokens, max_waiting_tokens, max_batch_size, + served_model_name.clone(), ) - .await?; + .await?; // Run server server::run( @@ -198,6 +205,7 @@ async fn main() -> Result<(), RouterError> { max_client_batch_size, usage_stats, payload_limit, + served_model_name.clone(), ) .await?; Ok(()) diff --git a/backends/v2/src/queue.rs 
b/backends/v2/src/queue.rs index c9a9335dd..07255d362 100644 --- a/backends/v2/src/queue.rs +++ b/backends/v2/src/queue.rs @@ -43,10 +43,11 @@ impl Queue { block_size: u32, window_size: Option, speculate: u32, + served_model_name: String, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); - + // Launch background queue task tokio::spawn(queue_task( requires_padding, @@ -54,6 +55,7 @@ impl Queue { window_size, speculate, queue_receiver, + served_model_name, )); Self { queue_sender } @@ -104,6 +106,7 @@ async fn queue_task( window_size: Option, speculate: u32, mut receiver: mpsc::UnboundedReceiver, + served_model_name: String, ) { let mut state = State::new(requires_padding, block_size, window_size, speculate); @@ -111,7 +114,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::gauge!("tgi_queue_size").increment(1.0); + metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0); } QueueCommand::NextBatch { min_size, @@ -122,9 +125,9 @@ async fn queue_task( span, } => span.in_scope(|| { let next_batch = - state.next_batch(min_size, max_size, prefill_token_budget, token_budget); + state.next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone()); response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); + metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64); }), } } @@ -191,6 +194,7 @@ impl State { max_size: Option, prefill_token_budget: u32, token_budget: u32, + served_model_name: String, ) -> Option { if self.entries.is_empty() { tracing::debug!("No queue"); @@ -232,7 +236,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1); tracing::debug!("Dropping entry"); continue; } @@ -340,7 +344,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); + metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } @@ -466,21 +470,23 @@ mod tests { #[test] fn test_next_batch_empty() { + let served_model_name = "bigscience/blomm-560m".to_string(); let mut state = State::new(false, 1, None, 0); - assert!(state.next_batch(None, None, 1, 1).is_none()); - assert!(state.next_batch(Some(1), None, 1, 1).is_none()); + assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).is_none()); + assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).is_none()); } #[test] fn test_next_batch_min_size() { + let served_model_name = "bigscience/blomm-560m".to_string(); let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -496,7 
+502,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), None, 2, 2).is_none()); + assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -506,13 +512,14 @@ mod tests { #[test] fn test_next_batch_max_size() { + let served_model_name = "bigscience/blomm-560m".to_string(); let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert!(entries.get(&0).unwrap().batch_time.is_some()); @@ -526,13 +533,14 @@ mod tests { #[test] fn test_next_batch_token_budget() { + let served_model_name = "bigscience/blomm-560m".to_string(); let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -545,7 +553,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 3, 3, served_model_name.clone()).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -559,14 +567,16 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, None, 0, served_model_name.clone()); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, None, 0, served_model_name.clone()); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -574,7 +584,8 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, None, 0, served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -607,7 +618,8 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, None, 0, served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -623,7 +635,9 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0); + let served_model_name = "bigscience/blomm-560m".to_string(); + + let queue = Queue::new(false, 1, None, 0, 
served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -648,7 +662,9 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2); + let served_model_name = "bigscience/blomm-560m".to_string(); + + let queue = Queue::new(false, 1, None, 2, served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -667,7 +683,9 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0); + let served_model_name = "bigscience/blomm-560m".to_string(); + + let queue = Queue::new(false, 1, None, 0, served_model_name.clone()); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 98e8d76f0..e7095e76b 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -34,6 +34,7 @@ impl BackendV3 { max_waiting_tokens: usize, max_batch_size: Option, shard_info: InfoResponse, + served_model_name: String, ) -> Self { if shard_info.support_chunking { tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored."); @@ -49,6 +50,7 @@ impl BackendV3 { shard_info.speculate, max_batch_total_tokens, shard_info.support_chunking, + served_model_name.clone(), ); let batching_task_notifier = Arc::new(Notify::new()); @@ -63,6 +65,7 @@ impl BackendV3 { shard_info.support_chunking, queue.clone(), batching_task_notifier.clone(), + served_model_name.clone(), )); Self { @@ -136,6 +139,7 @@ pub(crate) async fn batching_task( support_chunking: bool, queue: Queue, notifier: Arc, + served_model_name: String, ) { // Infinite loop loop { @@ -154,7 +158,7 @@ pub(crate) async fn batching_task( ) .await { - let mut cached_batch = prefill(&mut client, batch, None, &mut entries) + let mut cached_batch = prefill(&mut client, batch, None, &mut entries, served_model_name.clone()) .instrument(span) .await; let mut waiting_tokens = 1; @@ -167,8 +171,8 @@ pub(crate) async fn batching_task( let batch_max_tokens = batch.max_tokens; let current_tokens = batch.current_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64); let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); @@ -207,13 +211,13 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone()) .increment(1); } else { let counter = if support_chunking { - metrics::counter!("tgi_batch_concat", "reason" => "chunking") + metrics::counter!("tgi_batch_concat", "reason" => "chunking", "model_name" => served_model_name.clone()) } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone()) }; counter.increment(1); } @@ -226,7 +230,7 @@ pub(crate) async fn batching_task( 
entries.extend(new_entries); // Generate one token for both the cached batch and the new batch let new_cached_batch = - prefill(&mut client, new_batch, cached_batch, &mut entries) + prefill(&mut client, new_batch, cached_batch, &mut entries, served_model_name.clone()) .instrument(span) .await; if new_cached_batch.is_none() { @@ -250,7 +254,7 @@ pub(crate) async fn batching_task( // Generate one token for this new batch to have the attention past in cache let new_cached_batch = - prefill(&mut client, new_batch, None, &mut new_entries) + prefill(&mut client, new_batch, None, &mut new_entries, served_model_name.clone()) .instrument(span) .await; if new_cached_batch.is_some() { @@ -282,13 +286,13 @@ pub(crate) async fn batching_task( entry.temp_span = Some(entry_batch_span); }); - cached_batch = decode(&mut client, batches, &mut entries) + cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone()) .instrument(next_batch_span) .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size").set(0.0); - metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0); } } } @@ -299,40 +303,41 @@ async fn prefill( batch: Batch, cached_batch: Option, entries: &mut IntMap, + served_model_name: String, ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1); match client.prefill(batch, cached_batch).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + filter_send_generations(generations, entries, served_model_name.clone()); // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") + metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill", "model_name" => served_model_name.clone()) .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone()) .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone()) .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill", "model_name" => served_model_name.clone()) .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + metrics::counter!("tgi_batch_inference_success", "method" => 
"prefill", "model_name" => served_model_name.clone()).increment(1); next_batch } // If we have an error, we discard the whole batch Err(err) => { let _ = client.clear_cache(Some(batch_id)).await; - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + send_errors(err, entries, served_model_name.clone()); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1); None } } @@ -343,33 +348,34 @@ async fn decode( client: &mut ShardedClient, batches: Vec, entries: &mut IntMap, + served_model_name: String, ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + filter_send_generations(generations, entries, served_model_name.clone()); // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone()) .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1); next_batch } // If we have an error, we discard the whole batch @@ -377,8 +383,8 @@ async fn decode( for id in batch_ids { let _ = client.clear_cache(Some(id)).await; } - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + send_errors(err, entries, served_model_name.clone()); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1); None } } @@ -420,7 +426,7 @@ async fn filter_batch( /// Send one or multiple `InferStreamResponse` to Infer for all `entries` /// and filter entries #[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { +fn 
filter_send_generations(generations: Vec, entries: &mut IntMap, served_model_name: String) { generations.into_iter().for_each(|generation| { let id = generation.request_id; // Get entry @@ -434,9 +440,9 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped").increment(1); + metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1); }).unwrap_or(true); if stopped { entries.remove(&id).expect("ID not found in entries. This is a bug."); @@ -448,10 +454,11 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1); return Ok(true); } @@ -477,7 +484,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -537,12 +544,12 @@ fn send_responses( /// Send errors to Infer for all `entries` #[instrument(skip_all)] -fn send_errors(error: ClientError, entries: &mut IntMap) { +fn send_errors(error: ClientError, entries: &mut IntMap, served_model_name: String) { entries.drain().for_each(|(_, entry)| { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + metrics::counter!("tgi_request_failure", "err" => "generation", "model_name" => served_model_name.clone()).increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. 
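(Note, not part of the patch.) A minimal standalone sketch of the labelled-metrics pattern this change applies throughout the v2/v3 backends, assuming only the `metrics` crate macro syntax already used in this repository (`counter!`/`histogram!` with `"key" => value` label pairs) and some installed recorder such as metrics-exporter-prometheus; the function name and values below are illustrative, not part of TGI.

fn record_request_metrics(served_model_name: &str, duration_secs: f64, ok: bool) {
    // Label values must be owned (or 'static), which is why the patch clones
    // `served_model_name` at every call site instead of passing a borrowed &str.
    if ok {
        metrics::counter!("tgi_request_success", "model_name" => served_model_name.to_string())
            .increment(1);
    } else {
        metrics::counter!("tgi_request_failure", "err" => "generation", "model_name" => served_model_name.to_string())
            .increment(1);
    }
    metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.to_string())
        .record(duration_secs);
}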
diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index 09137853f..303564337 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -54,6 +54,7 @@ pub async fn connect_backend( max_batch_total_tokens: Option, max_waiting_tokens: usize, max_batch_size: Option, + served_model_name: String, ) -> Result<(BackendV3, BackendInfo), V3Error> { // Helper function let check_max_batch_total_tokens = |( @@ -161,6 +162,7 @@ pub async fn connect_backend( max_waiting_tokens, max_batch_size, shard_info, + served_model_name, ); tracing::info!("Using backend V3"); diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 52e41b55a..e371b2acb 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -7,6 +7,9 @@ use thiserror::Error; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + #[clap(long, env)] + served_model_name: String, + #[command(subcommand)] command: Option, @@ -74,6 +77,7 @@ struct Args { payload_limit: usize, } + #[derive(Debug, Subcommand)] enum Commands { PrintSchema, @@ -83,8 +87,11 @@ enum Commands { async fn main() -> Result<(), RouterError> { // Get args let args = Args::parse(); + let _served_model_name = args.served_model_name.clone(); + // Pattern match configuration let Args { + served_model_name, command, max_concurrent_requests, max_best_of, @@ -151,6 +158,7 @@ async fn main() -> Result<(), RouterError> { max_batch_total_tokens, max_waiting_tokens, max_batch_size, + served_model_name.clone(), ) .await?; @@ -214,6 +222,7 @@ async fn main() -> Result<(), RouterError> { max_client_batch_size, usage_stats, payload_limit, + served_model_name.clone(), ) .await?; Ok(()) diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 249eebf76..453d06f99 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -51,6 +51,7 @@ impl Queue { speculate: u32, max_batch_total_tokens: u32, support_chunking: bool, + served_model_name: String, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -65,6 +66,7 @@ impl Queue { max_batch_total_tokens, support_chunking, queue_receiver, + served_model_name, )); Self { queue_sender } @@ -124,6 +126,7 @@ async fn queue_task( max_batch_total_tokens: u32, support_chunking: bool, mut receiver: mpsc::UnboundedReceiver, + served_model_name: String, ) { let mut state = State::new( requires_padding, @@ -139,7 +142,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::gauge!("tgi_queue_size").increment(1.0); + metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0); } QueueCommand::NextBatch { min_size, @@ -150,11 +153,11 @@ async fn queue_task( span, } => { let next_batch = state - .next_batch(min_size, max_size, prefill_token_budget, token_budget) + .next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone()) .instrument(span) .await; response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); + metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64); } } } @@ -235,6 +238,7 @@ impl State { max_size: Option, prefill_token_budget: u32, token_budget: u32, + served_model_name: String, ) -> Option { if self.entries.is_empty() { tracing::debug!("No queue"); @@ -274,7 +278,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the 
request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1); tracing::debug!("Dropping entry"); continue; } @@ -478,7 +482,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); + metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } @@ -606,21 +610,24 @@ mod tests { #[tokio::test] async fn test_next_batch_empty() { + let served_model_name = "bigscience/blomm-560m".to_string(); let mut state = State::new(false, 1, false, None, 0, 16, false); - assert!(state.next_batch(None, None, 1, 1).await.is_none()); - assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); + assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).await.is_none()); + assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).await.is_none()); } #[tokio::test] async fn test_next_batch_min_size() { + let served_model_name = "bigscience/blomm-560m".to_string(); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -636,7 +643,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), None, 2, 2).await.is_none()); + assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).await.is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -646,13 +653,14 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { + let served_model_name = "bigscience/blomm-560m".to_string(); let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap(); + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert!(entries.get(&0).unwrap().batch_time.is_some()); @@ -666,13 +674,14 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { + let served_model_name = "bigscience/blomm-560m".to_string(); let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -685,7 +694,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap(); + let (entries, batch, _) = 
state.next_batch(None, None, 3, 3, served_model_name.clone()).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -699,14 +708,16 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, false, None, 0, 16, false); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone()); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, false, None, 0, 16, false); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone()); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -714,7 +725,8 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, false, None, 0, 16, false); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -747,7 +759,8 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, false, None, 0, 16, false); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -763,7 +776,8 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, false, None, 0, 16, false); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -788,7 +802,8 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(true, 1, false, None, 2, 16, false); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(true, 1, false, None, 2, 16, false, served_model_name.clone()); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -807,7 +822,8 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, false, None, 0, 16, false); + let served_model_name = "bigscience/blomm-560m".to_string(); + let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone()); let (entry, _) = default_entry(); queue.append(entry); diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fbbe8a2d7..7afe79e70 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -560,6 +560,10 @@ struct Args { #[clap(default_value = "bigscience/bloom-560m", long, env)] model_id: String, + /// Name under which the model is served. Defaults to `model_id` if not provided. + #[clap(long, env)] + served_model_name: Option, + /// The actual revision of the model if you're referring to a model /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`. 
#[clap(long, env)] @@ -1802,7 +1806,7 @@ fn spawn_webserver( "--master-shard-uds-path".to_string(), format!("{}-0", args.shard_uds_path), "--tokenizer-name".to_string(), - args.model_id, + args.model_id.clone(), "--payload-limit".to_string(), args.payload_limit.to_string(), ]; @@ -1973,6 +1977,12 @@ fn main() -> Result<(), LauncherError> { // Pattern match configuration let args: Args = Args::parse(); + let served_model_name = args.served_model_name + .clone() + .unwrap_or_else(|| args.model_id.clone()); + + env::set_var("SERVED_MODEL_NAME", &served_model_name); + // Filter events with LOG_LEVEL let varname = "LOG_LEVEL"; let env_filter = if let Ok(log_level) = std::env::var(varname) { diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 7eb8a41be..5a06966dd 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -100,6 +100,7 @@ impl Infer { pub(crate) async fn generate_stream<'a>( &'a self, request: GenerateRequest, + served_model_name: String, ) -> Result< ( OwnedSemaphorePermit, @@ -121,7 +122,7 @@ impl Infer { // Validate request let mut local_request = request.clone(); - let valid_request = self.validation.validate(request).await.map_err(|err| { + let valid_request = self.validation.validate(request, served_model_name.clone()).await.map_err(|err| { metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); err @@ -165,7 +166,7 @@ impl Infer { local_request.inputs.push_str(&generated_text.text); all_generated_text = all_generated_text.or(Some(generated_text)); - let valid_request = match self.validation.validate(local_request.clone()).await { + let valid_request = match self.validation.validate(local_request.clone(), served_model_name.clone()).await { Ok(valid_request) => valid_request, Err(err) => { tracing::debug!("Failed to continue request: {err}"); @@ -245,11 +246,12 @@ impl Infer { pub(crate) async fn generate( &self, request: GenerateRequest, + served_model_name: String, ) -> Result { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, stream) = self.generate_stream(request).await?; + let (_permit, _input_length, stream) = self.generate_stream(request, served_model_name).await?; // Return values let mut result_prefill = Vec::new(); @@ -322,13 +324,14 @@ impl Infer { &self, request: GenerateRequest, best_of: usize, + served_model_name: String, ) -> Result<(InferResponse, Vec), InferError> { // validate best_of parameter separately let best_of = self.validation.validate_best_of(best_of)?; // create multiple generate requests let mut infer_responses: Vec = - try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?; + try_join_all((0..best_of).map(|_| self.generate(request.clone(), served_model_name.clone()))).await?; // get the sequence with the highest log probability per token let mut max_index = 0; diff --git a/router/src/sagemaker.rs b/router/src/sagemaker.rs index 750ef222b..f2eb6c169 100644 --- a/router/src/sagemaker.rs +++ b/router/src/sagemaker.rs @@ -68,15 +68,16 @@ pub(crate) async fn sagemaker_compatibility( infer: Extension, compute_type: Extension, info: Extension, + served_model_name: Extension, Json(req): Json, ) -> Result)> { match req { SagemakerRequest::Generate(req) => { - compat_generate(default_return_full_text, infer, compute_type, Json(req)).await + compat_generate(default_return_full_text, infer, compute_type, served_model_name, 
Json(req)).await } - SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await, + SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, served_model_name, Json(req)).await, SagemakerRequest::Completion(req) => { - completions(infer, compute_type, info, Json(req)).await + completions(infer, compute_type, info, served_model_name, Json(req)).await } } } diff --git a/router/src/server.rs b/router/src/server.rs index e9aa4612b..09a77c81e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -129,6 +129,7 @@ pub(crate) async fn compat_generate( Extension(default_return_full_text): Extension, infer: Extension, compute_type: Extension, + served_model_name: Extension, Json(mut req): Json, ) -> Result)> { // default return_full_text given the pipeline_tag @@ -138,11 +139,11 @@ pub(crate) async fn compat_generate( // switch on stream if req.stream { - Ok(generate_stream(infer, compute_type, Json(req.into())) + Ok(generate_stream(infer, served_model_name.clone(), compute_type, Json(req.into())) .await .into_response()) } else { - let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?; + let (headers, Json(generation)) = generate(infer, served_model_name.clone(), compute_type, Json(req.into())).await?; // wrap generation inside a Vec to match api-inference Ok((headers, Json(vec![generation])).into_response()) } @@ -196,9 +197,10 @@ async fn openai_get_model_info(info: Extension) -> Json { )] async fn get_chat_tokenize( Extension(infer): Extension, + Extension(served_model_name): Extension, Json(chat): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { - metrics::counter!("tgi_request_count").increment(1); + metrics::counter!("tgi_request_count", "model_name" => served_model_name).increment(1); let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0; let input = generate_request.inputs.clone(); @@ -270,23 +272,25 @@ seed, )] async fn generate( infer: Extension, + served_model_name: Extension, Extension(ComputeType(compute_type)): Extension, Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); let (headers, _, response) = - generate_internal(infer, ComputeType(compute_type), Json(req), span).await?; + generate_internal(infer, served_model_name, ComputeType(compute_type), Json(req), span).await?; Ok((headers, response)) } pub(crate) async fn generate_internal( infer: Extension, + served_model_name: Extension, ComputeType(compute_type): ComputeType, Json(req): Json, span: tracing::Span, ) -> Result<(HeaderMap, u32, Json), (StatusCode, Json)> { let start_time = Instant::now(); - metrics::counter!("tgi_request_count").increment(1); + metrics::counter!("tgi_request_count", "model_name" => served_model_name.0.clone()).increment(1); // Do not long ultra long inputs, like image payloads. 
tracing::debug!( @@ -305,10 +309,10 @@ pub(crate) async fn generate_internal( // Inference let (response, best_of_responses) = match req.parameters.best_of { Some(best_of) if best_of > 1 => { - let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?; + let (response, best_of_responses) = infer.generate_best_of(req, best_of, served_model_name.0.clone()).await?; (response, Some(best_of_responses)) } - _ => (infer.generate(req).await?, None), + _ => (infer.generate(req, served_model_name.0.clone()).await?, None), }; // Token details @@ -405,14 +409,14 @@ pub(crate) async fn generate_internal( ); // Metrics - metrics::counter!("tgi_request_success").increment(1); - metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); - metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); - metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); - metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); - metrics::histogram!("tgi_request_mean_time_per_token_duration") + metrics::counter!("tgi_request_success", "model_name" => served_model_name.0.clone()).increment(1); + metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.0.clone()).record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.0.clone()).record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.0.clone()).record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.0.clone()).record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.0.clone()) .record(time_per_token.as_secs_f64()); - metrics::histogram!("tgi_request_generated_tokens") + metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.0.clone()) .record(response.generated_text.generated_tokens as f64); // Send response @@ -468,6 +472,7 @@ seed, )] async fn generate_stream( Extension(infer): Extension, + Extension(served_model_name): Extension, Extension(compute_type): Extension, Json(req): Json, ) -> ( @@ -476,7 +481,7 @@ async fn generate_stream( ) { let span = tracing::Span::current(); let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(req), span).await; + generate_stream_internal(infer, served_model_name, compute_type, Json(req), span).await; let response_stream = async_stream::stream! 
{ let mut response_stream = Box::pin(response_stream); @@ -495,6 +500,7 @@ async fn generate_stream( async fn generate_stream_internal( infer: Infer, + served_model_name: String, ComputeType(compute_type): ComputeType, Json(req): Json, span: tracing::Span, @@ -503,7 +509,7 @@ async fn generate_stream_internal( impl Stream>, ) { let start_time = Instant::now(); - metrics::counter!("tgi_request_count").increment(1); + metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1); tracing::debug!("Input: {}", req.inputs); @@ -540,7 +546,7 @@ async fn generate_stream_internal( tracing::error!("{err}"); yield Err(err); } else { - match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { + match infer.generate_stream(req, served_model_name.clone()).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives Ok((_permit, input_length, response_stream)) => { let mut index = 0; @@ -605,13 +611,13 @@ async fn generate_stream_internal( span.record("seed", format!("{:?}", generated_text.seed)); // Metrics - metrics::counter!("tgi_request_success").increment(1); - metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); - metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); - metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); - metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); - metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64()); - metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64); + metrics::counter!("tgi_request_success", "model_name" => served_model_name.clone()).increment(1); + metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.clone()).record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.clone()).record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.clone()).record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.clone()).record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.clone()).record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.clone()).record(generated_text.generated_tokens as f64); // StreamResponse end_reached = true; @@ -704,10 +710,11 @@ pub(crate) async fn completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, + Extension(served_model_name): Extension, Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::counter!("tgi_request_count").increment(1); + metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1); let CompletionRequest { model, @@ -798,6 +805,7 @@ pub(crate) async fn completions( let infer_clone = infer.clone(); let compute_type_clone = compute_type.clone(); let span_clone = span.clone(); + let served_model_name_clone = served_model_name.clone(); // Create a future for each generate_stream_internal call. 
             let generate_future = async move {
@@ -807,6 +815,7 @@ pub(crate) async fn completions(
                 tokio::spawn(async move {
                     let (headers, response_stream) = generate_stream_internal(
                         infer_clone.clone(),
+                        served_model_name_clone.clone(),
                         compute_type_clone.clone(),
                         Json(generate_request),
                         span_clone.clone(),
@@ -975,11 +984,13 @@ pub(crate) async fn completions(
         let responses = FuturesUnordered::new();
         for (index, generate_request) in generate_requests.into_iter().enumerate() {
             let infer_clone = infer.clone();
+            let served_model_name_clone = served_model_name.clone();
             let compute_type_clone = compute_type.clone();
             let span_clone = span.clone();
             let response_future = async move {
                 let result = generate_internal(
                     Extension(infer_clone),
+                    Extension(served_model_name_clone),
                     compute_type_clone,
                     Json(generate_request),
                     span_clone,
@@ -1230,10 +1241,11 @@ pub(crate) async fn chat_completions(
     Extension(infer): Extension,
     Extension(compute_type): Extension,
     Extension(info): Extension,
+    Extension(served_model_name): Extension,
     Json(chat): Json,
 ) -> Result)> {
     let span = tracing::Span::current();
-    metrics::counter!("tgi_request_count").increment(1);
+    metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);

     let ChatRequest {
         model,
         stream,
@@ -1255,7 +1267,7 @@ pub(crate) async fn chat_completions(
     // switch on stream
     if stream {
         let (headers, response_stream) =
-            generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
+            generate_stream_internal(infer, served_model_name, compute_type, Json(generate_request), span).await;

         // regex to match any function name
         let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
@@ -1389,7 +1401,7 @@ pub(crate) async fn chat_completions(
         Ok((headers, sse).into_response())
     } else {
         let (headers, input_length, Json(generation)) =
-            generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
+            generate_internal(Extension(infer), Extension(served_model_name), compute_type, Json(generate_request), span).await?;

         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -1688,6 +1700,7 @@ pub async fn run(
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
     payload_limit: usize,
+    served_model_name: String,
 ) -> Result<(), WebServerError> {
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
@@ -1963,6 +1976,7 @@ pub async fn run(
         compat_return_full_text,
         allow_origin,
         payload_limit,
+        served_model_name,
     )
     .await;

@@ -2024,6 +2038,7 @@ async fn start(
     compat_return_full_text: bool,
     allow_origin: Option,
     payload_limit: usize,
+    served_model_name: String,
 ) -> Result<(), WebServerError> {
     // Determine the server port based on the feature and environment variable.
     let port = if cfg!(feature = "google") {
@@ -2076,22 +2091,22 @@ async fn start(
         duration_buckets.push(value);
     }
     // Input Length buckets
-    let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
+    let input_length_matcher = Matcher::Full(format!("tgi_request_input_length{{model_name=\"{}\"}}", served_model_name));
     let input_length_buckets: Vec = (0..100)
         .map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64)
         .collect();
     // Generated tokens buckets
-    let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
+    let generated_tokens_matcher = Matcher::Full(format!("tgi_request_generated_tokens{{model_name=\"{}\"}}", served_model_name));
     let generated_tokens_buckets: Vec = (0..100)
         .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
         .collect();
     // Input Length buckets
-    let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens"));
+    let max_new_tokens_matcher = Matcher::Full(format!("tgi_request_max_new_tokens{{model_name=\"{}\"}}", served_model_name));
     let max_new_tokens_buckets: Vec = (0..100)
         .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
         .collect();
     // Batch size buckets
-    let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
+    let batch_size_matcher = Matcher::Full(format!("tgi_batch_next_size{{model_name=\"{}\"}}", served_model_name));
     let batch_size_buckets: Vec = (0..1024).map(|x| (x + 1) as f64).collect();
     // Speculated tokens buckets
     // let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
@@ -2334,7 +2349,8 @@ async fn start(
         .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
         .route("/invocations", post(sagemaker_compatibility))
-        .route("/tokenize", post(tokenize));
+        .route("/tokenize", post(tokenize))
+        .layer(Extension(served_model_name));

     if let Some(api_key) = api_key {
         let mut prefix = "Bearer ".to_string();
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 320e7f03f..3aebf7314 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,7 +1,7 @@
 use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
-    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, TokenizerTrait,
 };
 use crate::{PyTokenizer, Tokenizer};

@@ -133,6 +133,7 @@ impl Validation {
         add_special_tokens: bool,
         truncate: Option,
         max_new_tokens: Option,
+        served_model_name: String,
     ) -> Result<(Vec, Option>, usize, u32, u32), ValidationError> {
         // If we have a fast tokenizer
         let (encoding, inputs) = self
@@ -186,7 +187,7 @@ impl Validation {
             let ids = encoding.get_ids();
             let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();

-            metrics::histogram!("tgi_request_input_length").record(input_length as f64);
+            metrics::histogram!("tgi_request_input_length", "model_name" => served_model_name.clone()).record(input_length as f64);
             Ok((
                 inputs,
                 Some(input_ids),
@@ -201,6 +202,7 @@ impl Validation {
     pub(crate) async fn validate(
         &self,
         request: GenerateRequest,
+        served_model_name: String,
     ) -> Result {
         let GenerateParameters {
             best_of,
@@ -332,6 +334,7 @@ impl Validation {
                 request.add_special_tokens,
                 truncate,
                 max_new_tokens,
+                served_model_name.clone(),
             )
             .await?;

@@ -405,7 +408,7 @@ impl Validation {
             ignore_eos_token: false,
         };

metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64); + metrics::histogram!("tgi_request_max_new_tokens", "model_name" => served_model_name.clone()).record(max_new_tokens as f64); Ok(ValidGenerateRequest { inputs, @@ -953,10 +956,10 @@ mod tests { max_total_tokens, disable_grammar_support, ); - + let served_model_name = "bigscience/blomm-560m".to_string(); let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), true, None, Some(max_new_tokens)) + .validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name) .await { Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), @@ -989,9 +992,10 @@ mod tests { disable_grammar_support, ); + let served_model_name = "bigscience/blomm-560m".to_string(); let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), true, None, Some(max_new_tokens)) + .validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name) .await { Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), @@ -1022,6 +1026,7 @@ mod tests { max_total_tokens, disable_grammar_support, ); + let served_model_name = "bigscience/blomm-560m".to_string(); match validation .validate(GenerateRequest { inputs: "Hello".to_string(), @@ -1031,7 +1036,7 @@ mod tests { do_sample: false, ..default_parameters() }, - }) + }, served_model_name) .await { Err(ValidationError::BestOfSampling) => (), @@ -1062,6 +1067,7 @@ mod tests { max_total_tokens, disable_grammar_support, ); + let served_model_name = "bigscience/blomm-560m".to_string(); match validation .validate(GenerateRequest { inputs: "Hello".to_string(), @@ -1071,7 +1077,7 @@ mod tests { max_new_tokens: Some(5), ..default_parameters() }, - }) + }, served_model_name.clone()) .await { Err(ValidationError::TopP) => (), @@ -1087,7 +1093,7 @@ mod tests { max_new_tokens: Some(5), ..default_parameters() }, - }) + }, served_model_name.clone()) .await { Ok(_) => (), @@ -1103,7 +1109,7 @@ mod tests { max_new_tokens: Some(5), ..default_parameters() }, - }) + }, served_model_name.clone()) .await .unwrap(); // top_p == 1.0 is invalid for users to ask for but it's the default resolved value. @@ -1133,6 +1139,7 @@ mod tests { max_total_tokens, disable_grammar_support, ); + let served_model_name = "bigscience/blomm-560m".to_string(); match validation .validate(GenerateRequest { inputs: "Hello".to_string(), @@ -1142,7 +1149,7 @@ mod tests { max_new_tokens: Some(5), ..default_parameters() }, - }) + }, served_model_name.clone()) .await { Err(ValidationError::TopNTokens(4, 5)) => (), @@ -1158,7 +1165,7 @@ mod tests { max_new_tokens: Some(5), ..default_parameters() }, - }) + }, served_model_name.clone()) .await .unwrap(); @@ -1171,7 +1178,7 @@ mod tests { max_new_tokens: Some(5), ..default_parameters() }, - }) + }, served_model_name.clone()) .await .unwrap(); @@ -1184,7 +1191,7 @@ mod tests { max_new_tokens: Some(5), ..default_parameters() }, - }) + }, served_model_name.clone()) .await .unwrap(); diff --git a/router/src/vertex.rs b/router/src/vertex.rs index 38695532c..9847e7d24 100644 --- a/router/src/vertex.rs +++ b/router/src/vertex.rs @@ -69,11 +69,12 @@ example = json ! 
({"error": "Incomplete generation"})), )] pub(crate) async fn vertex_compatibility( Extension(infer): Extension, + Extension(served_model_name): Extension, Extension(compute_type): Extension, Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::counter!("tgi_request_count").increment(1); + metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1); // check that theres at least one instance if req.instances.is_empty() { @@ -111,12 +112,14 @@ pub(crate) async fn vertex_compatibility( }; let infer_clone = infer.clone(); + let served_model_name_clone = served_model_name.clone(); let compute_type_clone = compute_type.clone(); let span_clone = span.clone(); futures.push(async move { generate_internal( Extension(infer_clone), + Extension(served_model_name_clone), compute_type_clone, Json(generate_request), span_clone,