Added a model_name label to metrics and an optional --served-model-name argument

This commit is contained in:
yashaswipiplani 2025-02-27 08:35:33 +00:00
parent 5eec3a8bb6
commit 380e73dba9
16 changed files with 292 additions and 167 deletions
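For orientation before the file-by-file diff: a minimal sketch (not part of the commit) of the pattern applied at every metric call site below. The served name falls back to model_id when --served-model-name is absent, then rides along as a model_name label through the same metrics crate macros the codebase already uses. The function and variable names here are hypothetical; only the metric and label names come from the diff.

// Illustrative sketch of the label-threading pattern used throughout this commit.
fn record_request(model_id: &str, served_model_name: Option<&str>) {
    // Fallback mirrors `unwrap_or_else(|| args.model_id.clone())` below.
    let model_name = served_model_name.unwrap_or(model_id).to_string();
    // Static metric name plus a dynamic `model_name` label, as in the diff.
    metrics::counter!("tgi_request_count", "model_name" => model_name.clone()).increment(1);
    metrics::histogram!("tgi_request_duration", "model_name" => model_name).record(0.042);
}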

View File

@ -19,6 +19,10 @@ struct Args {
#[clap(long, env)]
model_id: String,
/// Name under which the model is served. Defaults to `model_id` if not provided.
#[clap(long, env)]
served_model_name: Option<String>,
/// Revision of the model.
#[clap(default_value = "main", long, env)]
revision: String,
@ -152,6 +156,10 @@ struct Args {
async fn main() -> Result<(), RouterError> {
let args = Args::parse();
let served_model_name = args.served_model_name
.clone()
.unwrap_or_else(|| args.model_id.clone());
logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
let n_threads = match args.n_threads {
@ -264,6 +272,7 @@ async fn main() -> Result<(), RouterError> {
args.max_client_batch_size,
args.usage_stats,
args.payload_limit,
served_model_name,
)
.await?;
Ok(())

View File

@ -45,6 +45,8 @@ struct Args {
revision: Option<String>,
#[clap(long, env)]
model_id: String,
#[clap(long, env)]
served_model_name: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -227,6 +229,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
tokenizer_config_path,
revision,
model_id,
served_model_name,
validation_workers,
json_output,
otlp_endpoint,
@ -239,6 +242,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
payload_limit,
} = args;
let served_model_name = served_model_name.unwrap_or_else(|| model_id.clone());
// Launch Tokio runtime
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
@ -318,6 +325,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_client_batch_size,
usage_stats,
payload_limit,
served_model_name,
)
.await?;
Ok(())

View File

@ -34,6 +34,7 @@ impl BackendV2 {
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
served_model_name: String,
) -> Self {
// Infer shared state
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
@ -44,7 +45,7 @@ impl BackendV2 {
_ => unreachable!(),
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let queue = Queue::new(requires_padding, block_size, window_size, speculate, served_model_name.clone());
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
@ -57,6 +58,7 @@ impl BackendV2 {
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
served_model_name.clone(),
));
Self {
@ -128,6 +130,7 @@ pub(crate) async fn batching_task(
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
served_model_name: String,
) {
// Infinite loop
loop {
@ -146,7 +149,7 @@ pub(crate) async fn batching_task(
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
let mut cached_batch = prefill(&mut client, batch, &mut entries, served_model_name.clone())
.instrument(span)
.await;
let mut waiting_tokens = 1;
@ -158,8 +161,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -180,10 +183,10 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone())
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone())
.increment(1);
}
@ -199,7 +202,7 @@ pub(crate) async fn batching_task(
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries, served_model_name.clone())
.instrument(span)
.await;
// Reset waiting counter
@ -225,13 +228,13 @@ pub(crate) async fn batching_task(
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone())
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0);
}
}
}
@ -241,36 +244,37 @@ async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
served_model_name: String,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
filter_send_generations(generations, entries, served_model_name.clone());
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill", "model_name" => served_model_name.clone())
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill", "model_name" => served_model_name.clone())
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
None
}
}
@ -281,33 +285,34 @@ async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
served_model_name: String,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
filter_send_generations(generations, entries, served_model_name.clone());
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -316,7 +321,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
None
}
}
@ -358,7 +363,7 @@ async fn filter_batch(
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
@ -372,9 +377,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// Send generation responses back to the infer task
// If we receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating, hence the unwrap_or(true)
let stopped = send_responses(generation, entry).inspect_err(|_err| {
let stopped = send_responses(generation, entry, served_model_name.clone()).inspect_err(|_err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
@ -386,10 +391,11 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
fn send_responses(
generation: Generation,
entry: &Entry,
served_model_name: String,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
return Ok(true);
}
@ -415,7 +421,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()

View File

@ -39,6 +39,7 @@ pub async fn connect_backend(
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
served_model_name: String,
) -> Result<(BackendV2, BackendInfo), V2Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
@ -119,6 +120,7 @@ pub async fn connect_backend(
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
served_model_name,
);
tracing::info!("Using backend V3");

View File

@ -7,6 +7,9 @@ use thiserror::Error;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(long, env)]
served_model_name: String,
#[command(subcommand)]
command: Option<Commands>,
@ -83,8 +86,11 @@ enum Commands {
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
let _served_model_name = args.served_model_name.clone();
// Pattern match configuration
let Args {
served_model_name,
command,
max_concurrent_requests,
max_best_of,
@ -170,6 +176,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
served_model_name.clone(),
)
.await?;
@ -198,6 +205,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
served_model_name.clone(),
)
.await?;
Ok(())

View File

@ -43,6 +43,7 @@ impl Queue {
block_size: u32,
window_size: Option<u32>,
speculate: u32,
served_model_name: String,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@ -54,6 +55,7 @@ impl Queue {
window_size,
speculate,
queue_receiver,
served_model_name,
));
Self { queue_sender }
@ -104,6 +106,7 @@ async fn queue_task(
window_size: Option<u32>,
speculate: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
served_model_name: String,
) {
let mut state = State::new(requires_padding, block_size, window_size, speculate);
@ -111,7 +114,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::gauge!("tgi_queue_size").increment(1.0);
metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@ -122,9 +125,9 @@ async fn queue_task(
span,
} => span.in_scope(|| {
let next_batch =
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
state.next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone());
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64);
}),
}
}
@ -191,6 +194,7 @@ impl State {
max_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
served_model_name: String,
) -> Option<NextBatch> {
if self.entries.is_empty() {
tracing::debug!("No queue");
@ -232,7 +236,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
tracing::debug!("Dropping entry");
continue;
}
@ -340,7 +344,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}
@ -466,21 +470,23 @@ mod tests {
#[test]
fn test_next_batch_empty() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, None, 0);
assert!(state.next_batch(None, None, 1, 1).is_none());
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).is_none());
assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).is_none());
}
#[test]
fn test_next_batch_min_size() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
@ -496,7 +502,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
@ -506,13 +512,14 @@ mod tests {
#[test]
fn test_next_batch_max_size() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert!(entries.get(&0).unwrap().batch_time.is_some());
@ -526,13 +533,14 @@ mod tests {
#[test]
fn test_next_batch_token_budget() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
@ -545,7 +553,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
let (entries, batch, _) = state.next_batch(None, None, 3, 3, served_model_name.clone()).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
@ -559,14 +567,16 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -574,7 +584,8 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -607,7 +618,8 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -623,7 +635,9 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -648,7 +662,9 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, None, 2, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -667,7 +683,9 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
let (entry, _) = default_entry();
queue.append(entry);

View File

@ -34,6 +34,7 @@ impl BackendV3 {
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
shard_info: InfoResponse,
served_model_name: String,
) -> Self {
if shard_info.support_chunking {
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
@ -49,6 +50,7 @@ impl BackendV3 {
shard_info.speculate,
max_batch_total_tokens,
shard_info.support_chunking,
served_model_name.clone(),
);
let batching_task_notifier = Arc::new(Notify::new());
@ -63,6 +65,7 @@ impl BackendV3 {
shard_info.support_chunking,
queue.clone(),
batching_task_notifier.clone(),
served_model_name.clone(),
));
Self {
@ -136,6 +139,7 @@ pub(crate) async fn batching_task(
support_chunking: bool,
queue: Queue,
notifier: Arc<Notify>,
served_model_name: String,
) {
// Infinite loop
loop {
@ -154,7 +158,7 @@ pub(crate) async fn batching_task(
)
.await
{
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
let mut cached_batch = prefill(&mut client, batch, None, &mut entries, served_model_name.clone())
.instrument(span)
.await;
let mut waiting_tokens = 1;
@ -167,8 +171,8 @@ pub(crate) async fn batching_task(
let batch_max_tokens = batch.max_tokens;
let current_tokens = batch.current_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64);
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
@ -207,13 +211,13 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone())
.increment(1);
} else {
let counter = if support_chunking {
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
metrics::counter!("tgi_batch_concat", "reason" => "chunking", "model_name" => served_model_name.clone())
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone())
};
counter.increment(1);
}
@ -226,7 +230,7 @@ pub(crate) async fn batching_task(
entries.extend(new_entries);
// Generate one token for both the cached batch and the new batch
let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
prefill(&mut client, new_batch, cached_batch, &mut entries, served_model_name.clone())
.instrument(span)
.await;
if new_cached_batch.is_none() {
@ -250,7 +254,7 @@ pub(crate) async fn batching_task(
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
prefill(&mut client, new_batch, None, &mut new_entries)
prefill(&mut client, new_batch, None, &mut new_entries, served_model_name.clone())
.instrument(span)
.await;
if new_cached_batch.is_some() {
@ -282,13 +286,13 @@ pub(crate) async fn batching_task(
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone())
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0);
}
}
}
@ -299,40 +303,41 @@ async fn prefill(
batch: Batch,
cached_batch: Option<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
served_model_name: String,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
match client.prefill(batch, cached_batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
filter_send_generations(generations, entries, served_model_name.clone());
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
send_errors(err, entries, served_model_name.clone());
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
None
}
}
@ -343,33 +348,34 @@ async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
served_model_name: String,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
filter_send_generations(generations, entries, served_model_name.clone());
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -377,8 +383,8 @@ async fn decode(
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
send_errors(err, entries, served_model_name.clone());
metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
None
}
}
@ -420,7 +426,7 @@ async fn filter_batch(
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
@ -434,9 +440,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// Send generation responses back to the infer task
// If we receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating, hence the unwrap_or(true)
let stopped = send_responses(generation, entry).inspect_err(|_err| {
let stopped = send_responses(generation, entry, served_model_name.clone()).inspect_err(|_err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
@ -448,10 +454,11 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
fn send_responses(
generation: Generation,
entry: &Entry,
served_model_name: String,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
return Ok(true);
}
@ -477,7 +484,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
@ -537,12 +544,12 @@ fn send_responses(
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
metrics::counter!("tgi_request_failure", "err" => "generation", "model_name" => served_model_name.clone()).increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.

View File

@ -54,6 +54,7 @@ pub async fn connect_backend(
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
served_model_name: String,
) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function
let check_max_batch_total_tokens = |(
@ -161,6 +162,7 @@ pub async fn connect_backend(
max_waiting_tokens,
max_batch_size,
shard_info,
served_model_name,
);
tracing::info!("Using backend V3");

View File

@ -7,6 +7,9 @@ use thiserror::Error;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(long, env)]
served_model_name: String,
#[command(subcommand)]
command: Option<Commands>,
@ -74,6 +77,7 @@ struct Args {
payload_limit: usize,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
@ -83,8 +87,11 @@ enum Commands {
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
let _served_model_name = args.served_model_name.clone();
// Pattern match configuration
let Args {
served_model_name,
command,
max_concurrent_requests,
max_best_of,
@ -151,6 +158,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
served_model_name.clone(),
)
.await?;
@ -214,6 +222,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
served_model_name.clone(),
)
.await?;
Ok(())

View File

@ -51,6 +51,7 @@ impl Queue {
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
served_model_name: String,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@ -65,6 +66,7 @@ impl Queue {
max_batch_total_tokens,
support_chunking,
queue_receiver,
served_model_name,
));
Self { queue_sender }
@ -124,6 +126,7 @@ async fn queue_task(
max_batch_total_tokens: u32,
support_chunking: bool,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
served_model_name: String,
) {
let mut state = State::new(
requires_padding,
@ -139,7 +142,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::gauge!("tgi_queue_size").increment(1.0);
metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@ -150,11 +153,11 @@ async fn queue_task(
span,
} => {
let next_batch = state
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone())
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64);
}
}
}
@ -235,6 +238,7 @@ impl State {
max_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
served_model_name: String,
) -> Option<NextBatch> {
if self.entries.is_empty() {
tracing::debug!("No queue");
@ -274,7 +278,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
tracing::debug!("Dropping entry");
continue;
}
@ -478,7 +482,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}
@ -606,21 +610,24 @@ mod tests {
#[tokio::test]
async fn test_next_batch_empty() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, false, None, 0, 16, false);
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).await.is_none());
}
#[tokio::test]
async fn test_next_batch_min_size() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
@ -636,7 +643,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).await.is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
@ -646,13 +653,14 @@ mod tests {
#[tokio::test]
async fn test_next_batch_max_size() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert!(entries.get(&0).unwrap().batch_time.is_some());
@ -666,13 +674,14 @@ mod tests {
#[tokio::test]
async fn test_next_batch_token_budget() {
let served_model_name = "bigscience/blomm-560m".to_string();
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
@ -685,7 +694,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
let (entries, batch, _) = state.next_batch(None, None, 3, 3, served_model_name.clone()).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
@ -699,14 +708,16 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -714,7 +725,8 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -747,7 +759,8 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -763,7 +776,8 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -788,7 +802,8 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(true, 1, false, None, 2, 16, false);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(true, 1, false, None, 2, 16, false, served_model_name.clone());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -807,7 +822,8 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let served_model_name = "bigscience/blomm-560m".to_string();
let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
let (entry, _) = default_entry();
queue.append(entry);

View File

@ -560,6 +560,10 @@ struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
model_id: String,
/// Name under which the model is served. Defaults to `model_id` if not provided.
#[clap(long, env)]
served_model_name: Option<String>,
/// The actual revision of the model if you're referring to a model
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
#[clap(long, env)]
@ -1802,7 +1806,7 @@ fn spawn_webserver(
"--master-shard-uds-path".to_string(),
format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(),
args.model_id,
args.model_id.clone(),
"--payload-limit".to_string(),
args.payload_limit.to_string(),
];
@ -1973,6 +1977,12 @@ fn main() -> Result<(), LauncherError> {
// Pattern match configuration
let args: Args = Args::parse();
let served_model_name = args.served_model_name
.clone()
.unwrap_or_else(|| args.model_id.clone());
env::set_var("SERVED_MODEL_NAME", &served_model_name);
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
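The launcher above exports SERVED_MODEL_NAME so the router can pick the value up without an explicit flag. A hedged sketch of how that hand-off works with clap's bare `env` attribute, assuming clap is built with its env feature; DemoArgs and the binary name are hypothetical:

use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Filled from --served-model-name or the SERVED_MODEL_NAME env var.
    #[clap(long, env)]
    served_model_name: Option<String>,
}

fn main() {
    std::env::set_var("SERVED_MODEL_NAME", "my-model");
    // No CLI flag passed: clap falls back to the environment variable.
    let args = DemoArgs::parse_from(["demo"]);
    assert_eq!(args.served_model_name.as_deref(), Some("my-model"));
}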

View File

@ -100,6 +100,7 @@ impl Infer {
pub(crate) async fn generate_stream<'a>(
&'a self,
request: GenerateRequest,
served_model_name: String,
) -> Result<
(
OwnedSemaphorePermit,
@ -121,7 +122,7 @@ impl Infer {
// Validate request
let mut local_request = request.clone();
let valid_request = self.validation.validate(request).await.map_err(|err| {
let valid_request = self.validation.validate(request, served_model_name.clone()).await.map_err(|err| {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
err
@ -165,7 +166,7 @@ impl Infer {
local_request.inputs.push_str(&generated_text.text);
all_generated_text = all_generated_text.or(Some(generated_text));
let valid_request = match self.validation.validate(local_request.clone()).await {
let valid_request = match self.validation.validate(local_request.clone(), served_model_name.clone()).await {
Ok(valid_request) => valid_request,
Err(err) => {
tracing::debug!("Failed to continue request: {err}");
@ -245,11 +246,12 @@ impl Infer {
pub(crate) async fn generate(
&self,
request: GenerateRequest,
served_model_name: String,
) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, stream) = self.generate_stream(request).await?;
let (_permit, _input_length, stream) = self.generate_stream(request, served_model_name).await?;
// Return values
let mut result_prefill = Vec::new();
@ -322,13 +324,14 @@ impl Infer {
&self,
request: GenerateRequest,
best_of: usize,
served_model_name: String,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;
// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
try_join_all((0..best_of).map(|_| self.generate(request.clone(), served_model_name.clone()))).await?;
// get the sequence with the highest log probability per token
let mut max_index = 0;

View File

@ -68,15 +68,16 @@ pub(crate) async fn sagemaker_compatibility(
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
info: Extension<Info>,
served_model_name: Extension<String>,
Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
match req {
SagemakerRequest::Generate(req) => {
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
compat_generate(default_return_full_text, infer, compute_type, served_model_name, Json(req)).await
}
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, served_model_name, Json(req)).await,
SagemakerRequest::Completion(req) => {
completions(infer, compute_type, info, Json(req)).await
completions(infer, compute_type, info, served_model_name, Json(req)).await
}
}
}
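The handlers above (and in the next file) receive the served name through axum's Extension<String> extractor, which implies the value is registered as a request extension when the router is built; that registration is not shown in this diff. A hedged sketch, with hypothetical route and handler names:

use axum::{routing::post, Extension, Router};

// Hypothetical handler: pulls out the name injected by the layer below.
async fn handler(Extension(served_model_name): Extension<String>) -> String {
    served_model_name
}

fn app(served_model_name: String) -> Router {
    Router::new()
        .route("/generate", post(handler))
        // Makes `Extension<String>` extraction available to every handler.
        .layer(Extension(served_model_name))
}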

View File

@ -129,6 +129,7 @@ pub(crate) async fn compat_generate(
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
served_model_name: Extension<String>,
Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag
@ -138,11 +139,11 @@ pub(crate) async fn compat_generate(
// switch on stream
if req.stream {
Ok(generate_stream(infer, compute_type, Json(req.into()))
Ok(generate_stream(infer, served_model_name.clone(), compute_type, Json(req.into()))
.await
.into_response())
} else {
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
let (headers, Json(generation)) = generate(infer, served_model_name.clone(), compute_type, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response())
}
@ -196,9 +197,10 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
)]
async fn get_chat_tokenize(
Extension(infer): Extension<Infer>,
Extension(served_model_name): Extension<String>,
Json(chat): Json<ChatRequest>,
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name).increment(1);
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
let input = generate_request.inputs.clone();
@ -270,23 +272,25 @@ seed,
)]
async fn generate(
infer: Extension<Infer>,
served_model_name: Extension<String>,
Extension(ComputeType(compute_type)): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
let (headers, _, response) =
generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
generate_internal(infer, served_model_name, ComputeType(compute_type), Json(req), span).await?;
Ok((headers, response))
}
pub(crate) async fn generate_internal(
infer: Extension<Infer>,
served_model_name: Extension<String>,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
span: tracing::Span,
) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.0.clone()).increment(1);
// Do not log ultra long inputs, like image payloads.
tracing::debug!(
@ -305,10 +309,10 @@ pub(crate) async fn generate_internal(
// Inference
let (response, best_of_responses) = match req.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
let (response, best_of_responses) = infer.generate_best_of(req, best_of, served_model_name.0.clone()).await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req).await?, None),
_ => (infer.generate(req, served_model_name.0.clone()).await?, None),
};
// Token details
@ -405,14 +409,14 @@ pub(crate) async fn generate_internal(
);
// Metrics
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration")
metrics::counter!("tgi_request_success", "model_name" => served_model_name.0.clone()).increment(1);
metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.0.clone()).record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.0.clone()).record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.0.clone()).record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.0.clone()).record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.0.clone())
.record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens")
metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.0.clone())
.record(response.generated_text.generated_tokens as f64);
// Send response
@ -468,6 +472,7 @@ seed,
)]
async fn generate_stream(
Extension(infer): Extension<Infer>,
Extension(served_model_name): Extension<String>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> (
@ -476,7 +481,7 @@ async fn generate_stream(
) {
let span = tracing::Span::current();
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(req), span).await;
generate_stream_internal(infer, served_model_name, compute_type, Json(req), span).await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
@ -495,6 +500,7 @@ async fn generate_stream(
async fn generate_stream_internal(
infer: Infer,
served_model_name: String,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
span: tracing::Span,
@ -503,7 +509,7 @@ async fn generate_stream_internal(
impl Stream<Item = Result<StreamResponse, InferError>>,
) {
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
tracing::debug!("Input: {}", req.inputs);
@ -540,7 +546,7 @@ async fn generate_stream_internal(
tracing::error!("{err}");
yield Err(err);
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
match infer.generate_stream(req, served_model_name.clone()).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, input_length, response_stream)) => {
let mut index = 0;
@ -605,13 +611,13 @@ async fn generate_stream_internal(
span.record("seed", format!("{:?}", generated_text.seed));
// Metrics
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);
metrics::counter!("tgi_request_success", "model_name" => served_model_name.clone()).increment(1);
metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.clone()).record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.clone()).record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.clone()).record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.clone()).record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.clone()).record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.clone()).record(generated_text.generated_tokens as f64);
// StreamResponse
end_reached = true;
@ -704,10 +710,11 @@ pub(crate) async fn completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Extension(served_model_name): Extension<String>,
Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
let CompletionRequest {
model,
@ -798,6 +805,7 @@ pub(crate) async fn completions(
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
let served_model_name_clone = served_model_name.clone();
// Create a future for each generate_stream_internal call.
let generate_future = async move {
@ -807,6 +815,7 @@ pub(crate) async fn completions(
tokio::spawn(async move {
let (headers, response_stream) = generate_stream_internal(
infer_clone.clone(),
served_model_name_clone.clone(),
compute_type_clone.clone(),
Json(generate_request),
span_clone.clone(),
@ -975,11 +984,13 @@ pub(crate) async fn completions(
let responses = FuturesUnordered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() {
let infer_clone = infer.clone();
let served_model_name_clone = served_model_name.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
let response_future = async move {
let result = generate_internal(
Extension(infer_clone),
Extension(served_model_name_clone),
compute_type_clone,
Json(generate_request),
span_clone,
@ -1230,10 +1241,11 @@ pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Extension(served_model_name): Extension<String>,
Json(chat): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
let ChatRequest {
model,
stream,
@ -1255,7 +1267,7 @@ pub(crate) async fn chat_completions(
// switch on stream
if stream {
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
generate_stream_internal(infer, served_model_name, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
@ -1389,7 +1401,7 @@ pub(crate) async fn chat_completions(
Ok((headers, sse).into_response())
} else {
let (headers, input_length, Json(generation)) =
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
generate_internal(Extension(infer), Extension(served_model_name), compute_type, Json(generate_request), span).await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -1688,6 +1700,7 @@ pub async fn run(
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
served_model_name: String,
) -> Result<(), WebServerError> {
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
@ -1963,6 +1976,7 @@ pub async fn run(
compat_return_full_text,
allow_origin,
payload_limit,
served_model_name,
)
.await;
@ -2024,6 +2038,7 @@ async fn start(
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
served_model_name: String,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
@ -2076,22 +2091,22 @@ async fn start(
duration_buckets.push(value);
}
// Input Length buckets
let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
let input_length_matcher = Matcher::Full(format!("tgi_request_input_length{{model_name=\"{}\"}}", served_model_name));
let input_length_buckets: Vec<f64> = (0..100)
.map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64)
.collect();
// Generated tokens buckets
let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
let generated_tokens_matcher = Matcher::Full(format!("tgi_request_generated_tokens{{model_name=\"{}\"}}", served_model_name));
let generated_tokens_buckets: Vec<f64> = (0..100)
.map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
.collect();
// Max new tokens buckets
let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens"));
let max_new_tokens_matcher = Matcher::Full(format!("tgi_request_max_new_tokens{{model_name=\"{}\"}}", served_model_name));
let max_new_tokens_buckets: Vec<f64> = (0..100)
.map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
.collect();
// Batch size buckets
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_matcher = Matcher::Full(format!("tgi_batch_next_size{{model_name=\"{}\"}}", served_model_name));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
// let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
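For context on the block above: a minimal sketch (not part of this hunk) of how one matcher/bucket pair is typically wired into the Prometheus recorder with metrics-exporter-prometheus. The builder methods exist in that crate; the helper name and error handling are illustrative. One caveat worth verifying against the pinned crate version: the bucket matcher is resolved against the metric name, so a Matcher::Full string that embeds a label selector (as in the lines above) may not match, and the sketch therefore passes a plain metric name.

use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};

// Illustrative helper: configure buckets for one histogram and install the
// global recorder, returning the handle used to render /metrics.
fn install_prometheus_recorder(metric_name: &str, buckets: &[f64]) -> PrometheusHandle {
    PrometheusBuilder::new()
        // set_buckets_for_metric returns a Result because empty bucket slices are rejected.
        .set_buckets_for_metric(Matcher::Full(metric_name.to_string()), buckets)
        .expect("histogram buckets must be non-empty")
        .install_recorder()
        .expect("failed to install the Prometheus recorder")
}

Under that assumption, labelled series such as tgi_request_input_length{model_name="..."} share the buckets configured for the bare metric name.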
@ -2334,7 +2349,8 @@ async fn start(
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility))
.route("/invocations", post(sagemaker_compatibility))
.route("/tokenize", post(tokenize));
.route("/tokenize", post(tokenize))
.layer(Extension(served_model_name));
if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();

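For readers tracing how the Extension layer added above is consumed: a self-contained sketch of an axum handler that extracts the shared served_model_name and attaches it as a metric label. The /health route and handler below are hypothetical stand-ins, not code from this repository; only the extraction and label pattern mirror the handlers in this diff.

use axum::{routing::get, Extension, Router};

// Hypothetical handler: the Extension layer makes the shared String available
// to every route registered before the layer on this router.
async fn health(Extension(served_model_name): Extension<String>) -> &'static str {
    // Same label convention as the real handlers; cloning keeps the shared
    // String available for later requests.
    metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
    "ok"
}

fn app(served_model_name: String) -> Router {
    Router::new()
        .route("/health", get(health))
        .layer(Extension(served_model_name))
}

In axum, .layer(...) wraps the routes registered before it, which is why the diff attaches the extension at the end of the route chain; a handler that declares Extension<String> without the layer in place is rejected at runtime with an internal server error.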
View File

@ -133,6 +133,7 @@ impl Validation {
add_special_tokens: bool,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
served_model_name: String,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
// If we have a fast tokenizer
let (encoding, inputs) = self
@ -186,7 +187,7 @@ impl Validation {
let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
metrics::histogram!("tgi_request_input_length", "model_name" => served_model_name.clone()).record(input_length as f64);
Ok((
inputs,
Some(input_ids),
@ -201,6 +202,7 @@ impl Validation {
pub(crate) async fn validate(
&self,
request: GenerateRequest,
served_model_name: String,
) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters {
best_of,
@ -332,6 +334,7 @@ impl Validation {
request.add_special_tokens,
truncate,
max_new_tokens,
served_model_name.clone(),
)
.await?;
@ -405,7 +408,7 @@ impl Validation {
ignore_eos_token: false,
};
metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);
metrics::histogram!("tgi_request_max_new_tokens", "model_name" => served_model_name.clone()).record(max_new_tokens as f64);
Ok(ValidGenerateRequest {
inputs,
@ -953,10 +956,10 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);
let served_model_name = "bigscience/bloom-560m".to_string();
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name)
.await
{
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -989,9 +992,10 @@ mod tests {
disable_grammar_support,
);
let served_model_name = "bigscience/bloom-560m".to_string();
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name)
.await
{
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -1022,6 +1026,7 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);
let served_model_name = "bigscience/bloom-560m".to_string();
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
@ -1031,7 +1036,7 @@ mod tests {
do_sample: false,
..default_parameters()
},
})
}, served_model_name)
.await
{
Err(ValidationError::BestOfSampling) => (),
@ -1062,6 +1067,7 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);
let served_model_name = "bigscience/bloom-560m".to_string();
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
@ -1071,7 +1077,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
{
Err(ValidationError::TopP) => (),
@ -1087,7 +1093,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
{
Ok(_) => (),
@ -1103,7 +1109,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
@ -1133,6 +1139,7 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);
let served_model_name = "bigscience/bloom-560m".to_string();
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
@ -1142,7 +1149,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
{
Err(ValidationError::TopNTokens(4, 5)) => (),
@ -1158,7 +1165,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();
@ -1171,7 +1178,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();
@ -1184,7 +1191,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();
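The tests above thread served_model_name into validate and validate_input but do not assert on the emitted label. A hypothetical sketch of such a check using the metrics-util debugging recorder follows; DebuggingRecorder and metrics::with_local_recorder are assumed to be available at versions compatible with the metrics facade used here, and the recorded value is arbitrary.

use metrics_util::debugging::DebuggingRecorder;

#[test]
fn input_length_histogram_carries_model_name_label() {
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Record through a thread-local recorder so the assertion does not depend
    // on a globally installed exporter.
    metrics::with_local_recorder(&recorder, || {
        metrics::histogram!("tgi_request_input_length", "model_name" => "bigscience/bloom-560m")
            .record(5.0);
    });

    // Every snapshot entry exposes its key, including labels.
    let labelled = snapshotter.snapshot().into_vec().into_iter().any(|(key, _, _, _)| {
        key.key()
            .labels()
            .any(|label| label.key() == "model_name" && label.value() == "bigscience/bloom-560m")
    });
    assert!(labelled);
}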

View File

@ -69,11 +69,12 @@ example = json ! ({"error": "Incomplete generation"})),
)]
pub(crate) async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(served_model_name): Extension<String>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
// check that there's at least one instance
if req.instances.is_empty() {
@ -111,12 +112,14 @@ pub(crate) async fn vertex_compatibility(
};
let infer_clone = infer.clone();
let served_model_name_clone = served_model_name.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
futures.push(async move {
generate_internal(
Extension(infer_clone),
Extension(served_model_name_clone),
compute_type_clone,
Json(generate_request),
span_clone,