Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 22:32:07 +00:00)
Added model name label to metrics and added an optional argument --served-model-name

commit 380e73dba9
parent 5eec3a8bb6
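For orientation, the two pieces this commit combines are sketched below. This is a hedged, minimal example, not part of the diff itself; it assumes the clap and metrics crates already used by the router: an optional served_model_name argument that falls back to model_id, plus a model_name label attached to a metric with the same macro syntax that appears in the changed lines (gauges and histograms take the label the same way).

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Name under which the model is served. Defaults to `model_id` if not provided.
        #[clap(long, env)]
        served_model_name: Option<String>,
        #[clap(long, env)]
        model_id: String,
    }

    fn main() {
        let args = Args::parse();
        // Fall back to the model id when no explicit serving name is given.
        let served_model_name = args
            .served_model_name
            .clone()
            .unwrap_or_else(|| args.model_id.clone());
        // Exported by the Prometheus endpoint as, e.g.,
        // tgi_request_count{model_name="bigscience/bloom-560m"}.
        metrics::counter!("tgi_request_count", "model_name" => served_model_name).increment(1);
    }

The diff below threads that served_model_name value from the launcher and backend arguments down into the queue, the batching task, and the HTTP handlers so that the existing metric calls pick up the label.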
@@ -19,6 +19,10 @@ struct Args {
     #[clap(long, env)]
     model_id: String,

+    /// Name under which the model is served. Defaults to `model_id` if not provided.
+    #[clap(long, env)]
+    served_model_name: Option<String>,
+
     /// Revision of the model.
     #[clap(default_value = "main", long, env)]
     revision: String,
@@ -152,6 +156,10 @@ struct Args {
 async fn main() -> Result<(), RouterError> {
     let args = Args::parse();

+    let served_model_name = args.served_model_name
+        .clone()
+        .unwrap_or_else(|| args.model_id.clone());
+
     logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);

     let n_threads = match args.n_threads {
@@ -264,6 +272,7 @@ async fn main() -> Result<(), RouterError> {
         args.max_client_batch_size,
         args.usage_stats,
         args.payload_limit,
+        served_model_name
     )
     .await?;
     Ok(())

@@ -45,6 +45,8 @@ struct Args {
     revision: Option<String>,
     #[clap(long, env)]
     model_id: String,
+    #[clap(long, env)]
+    served_model_name: Option<String>,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
     #[clap(long, env)]
@@ -227,6 +229,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         tokenizer_config_path,
         revision,
         model_id,
+        served_model_name,
         validation_workers,
         json_output,
         otlp_endpoint,
@@ -239,6 +242,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         payload_limit,
     } = args;

+    let served_model_name = args.served_model_name
+        .clone()
+        .unwrap_or_else(|| args.model_id.clone());
+
     // Launch Tokio runtime
     text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

@@ -318,6 +325,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        served_model_name,
     )
     .await?;
     Ok(())

@@ -34,6 +34,7 @@ impl BackendV2 {
         requires_padding: bool,
         window_size: Option<u32>,
         speculate: u32,
+        served_model_name: String,
     ) -> Self {
         // Infer shared state
         let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
@@ -44,7 +45,7 @@ impl BackendV2 {
             _ => unreachable!(),
         };

-        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+        let queue = Queue::new(requires_padding, block_size, window_size, speculate, served_model_name.clone());
        let batching_task_notifier = Arc::new(Notify::new());

         // Spawn batching background task that contains all the inference logic
@@ -57,6 +58,7 @@ impl BackendV2 {
             max_batch_size,
             queue.clone(),
             batching_task_notifier.clone(),
+            served_model_name.clone(),
         ));

         Self {
@@ -128,6 +130,7 @@ pub(crate) async fn batching_task(
     max_batch_size: Option<usize>,
     queue: Queue,
     notifier: Arc<Notify>,
+    served_model_name: String,
 ) {
     // Infinite loop
     loop {
@@ -146,7 +149,7 @@ pub(crate) async fn batching_task(
            )
            .await
        {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, served_model_name.clone())
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;
@@ -158,8 +161,8 @@ pub(crate) async fn batching_task(
                let batch_size = batch.size;
                let batch_max_tokens = batch.max_tokens;
                let mut batches = vec![batch];
-                metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
-                metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
+                metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64);
+                metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64);

                let min_size = if waiting_tokens >= max_waiting_tokens {
                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
@@ -180,10 +183,10 @@ pub(crate) async fn batching_task(
                {
                    // Tracking metrics
                    if min_size.is_some() {
-                        metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+                        metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone())
                            .increment(1);
                    } else {
-                        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+                        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone())
                            .increment(1);
                    }

@@ -199,7 +202,7 @@ pub(crate) async fn batching_task(
                    });

                    // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries, served_model_name.clone())
                        .instrument(span)
                        .await;
                    // Reset waiting counter
@@ -225,13 +228,13 @@ pub(crate) async fn batching_task(
                    entry.temp_span = Some(entry_batch_span);
                });

-                cached_batch = decode(&mut client, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone())
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
            }
-            metrics::gauge!("tgi_batch_current_size").set(0.0);
-            metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
+            metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0);
+            metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0);
        }
    }
 }
@@ -241,36 +244,37 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    served_model_name: String,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
-    metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
+    metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);

     match client.prefill(batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
-            filter_send_generations(generations, entries);
+            filter_send_generations(generations, entries, served_model_name.clone());

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

-            metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
+            metrics::histogram!("tgi_batch_forward_duration","method" => "prefill", "model_name" => served_model_name.clone())
                .record(timings.forward.as_secs_f64());
-            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone())
                .record(timings.decode.as_secs_f64());
-            metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone())
                .record(start_filtering_time.elapsed().as_secs_f64());
-            metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
+            metrics::histogram!("tgi_batch_inference_duration","method" => "prefill", "model_name" => served_model_name.clone())
                .record(start_time.elapsed().as_secs_f64());
-            metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
+            metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
-            metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
+            metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
             None
         }
     }
@@ -281,33 +285,34 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
+    served_model_name: String,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-    metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
+    metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);

     match client.decode(batches).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
-            filter_send_generations(generations, entries);
+            filter_send_generations(generations, entries, served_model_name.clone());

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

             if let Some(concat_duration) = timings.concat {
-                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
                    .record(concat_duration.as_secs_f64());
             }
-            metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(timings.forward.as_secs_f64());
-            metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(timings.decode.as_secs_f64());
-            metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(start_filtering_time.elapsed().as_secs_f64());
-            metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(start_time.elapsed().as_secs_f64());
-            metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
+            metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
             next_batch
         }
         // If we have an error, we discard the whole batch
@@ -316,7 +321,7 @@ async fn decode(
                 let _ = client.clear_cache(Some(id)).await;
             }
             send_errors(err, entries);
-            metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
+            metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
             None
         }
     }
@@ -358,7 +363,7 @@ async fn filter_batch(
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
 /// and filter entries
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
     generations.into_iter().for_each(|generation| {
         let id = generation.request_id;
         // Get entry
@@ -372,9 +377,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
         // Send generation responses back to the infer task
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
-        let stopped = send_responses(generation, entry).inspect_err(|_err| {
+        let stopped = send_responses(generation, entry, served_model_name.clone()).inspect_err(|_err| {
            tracing::error!("Entry response channel error.");
-            metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+            metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
        }).unwrap_or(true);
        if stopped {
            entries.remove(&id).expect("ID not found in entries. This is a bug.");
@@ -386,10 +391,11 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 fn send_responses(
     generation: Generation,
     entry: &Entry,
+    served_model_name: String,
 ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
     if entry.response_tx.is_closed() {
-        metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+        metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
         return Ok(true);
     }

@@ -415,7 +421,7 @@ fn send_responses(
     // Create last Token
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
-    metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
+    metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64);
     let mut iterator = tokens_
         .ids
         .into_iter()

@@ -39,6 +39,7 @@ pub async fn connect_backend(
     max_batch_total_tokens: Option<u32>,
     max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
+    served_model_name: String,
 ) -> Result<(BackendV2, BackendInfo), V2Error> {
     // Helper function
     let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
@@ -119,6 +120,7 @@ pub async fn connect_backend(
         shard_info.requires_padding,
         shard_info.window_size,
         shard_info.speculate,
+        served_model_name,
     );

     tracing::info!("Using backend V3");

@@ -7,6 +7,9 @@ use thiserror::Error;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
+    #[clap(long, env)]
+    served_model_name: String,
+
     #[command(subcommand)]
     command: Option<Commands>,

@@ -83,8 +86,11 @@ enum Commands {
 async fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
+    let _served_model_name = args.served_model_name.clone();
+
     // Pattern match configuration
     let Args {
+        served_model_name,
         command,
         max_concurrent_requests,
         max_best_of,
@@ -170,6 +176,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
+        served_model_name.clone(),
     )
     .await?;

@@ -198,6 +205,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        served_model_name.clone(),
     )
     .await?;
     Ok(())

@@ -43,6 +43,7 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        served_model_name: String,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -54,6 +55,7 @@ impl Queue {
             window_size,
             speculate,
             queue_receiver,
+            served_model_name,
         ));

         Self { queue_sender }
@@ -104,6 +106,7 @@ async fn queue_task(
     window_size: Option<u32>,
     speculate: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    served_model_name: String,
 ) {
     let mut state = State::new(requires_padding, block_size, window_size, speculate);

@@ -111,7 +114,7 @@ async fn queue_task(
         match cmd {
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
-                metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0);
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -122,9 +125,9 @@ async fn queue_task(
                 span,
             } => span.in_scope(|| {
                 let next_batch =
-                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
+                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone());
                 response_sender.send(next_batch).unwrap();
-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64);
             }),
         }
     }
@@ -191,6 +194,7 @@ impl State {
         max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
+        served_model_name: String,
     ) -> Option<NextBatch> {
         if self.entries.is_empty() {
             tracing::debug!("No queue");
@@ -232,7 +236,7 @@ impl State {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
-                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+                metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
                 tracing::debug!("Dropping entry");
                 continue;
             }
@@ -340,7 +344,7 @@ impl State {
         // Increment batch id
         self.next_batch_id += 1;

-        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
+        metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64);

         Some((batch_entries, batch, next_batch_span))
     }
@@ -466,21 +470,23 @@ mod tests {

     #[test]
     fn test_next_batch_empty() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, None, 0);

-        assert!(state.next_batch(None, None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).is_none());
     }

     #[test]
     fn test_next_batch_min_size() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -496,7 +502,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        assert!(state.next_batch(Some(2), None, 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -506,13 +512,14 @@ mod tests {

     #[test]
     fn test_next_batch_max_size() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -526,13 +533,14 @@ mod tests {

     #[test]
     fn test_next_batch_token_budget() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -545,7 +553,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3, served_model_name.clone()).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -559,14 +567,16 @@ mod tests {

     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, None, 0);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, None, 0);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, None, 0, served_model_name.clone());

         assert!(queue.next_batch(None, None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -574,7 +584,8 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -607,7 +618,8 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -623,7 +635,9 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, None, 0);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+
+        let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -648,7 +662,9 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+
+        let queue = Queue::new(false, 1, None, 2, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -667,7 +683,9 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, None, 0);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+
+        let queue = Queue::new(false, 1, None, 0, served_model_name.clone());
         let (entry, _) = default_entry();
         queue.append(entry);

@@ -34,6 +34,7 @@ impl BackendV3 {
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
         shard_info: InfoResponse,
+        served_model_name: String,
     ) -> Self {
         if shard_info.support_chunking {
             tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
@@ -49,6 +50,7 @@ impl BackendV3 {
             shard_info.speculate,
             max_batch_total_tokens,
             shard_info.support_chunking,
+            served_model_name.clone(),
         );
         let batching_task_notifier = Arc::new(Notify::new());

@@ -63,6 +65,7 @@ impl BackendV3 {
             shard_info.support_chunking,
             queue.clone(),
             batching_task_notifier.clone(),
+            served_model_name.clone(),
         ));

         Self {
@@ -136,6 +139,7 @@ pub(crate) async fn batching_task(
     support_chunking: bool,
     queue: Queue,
     notifier: Arc<Notify>,
+    served_model_name: String,
 ) {
     // Infinite loop
     loop {
@@ -154,7 +158,7 @@ pub(crate) async fn batching_task(
            )
            .await
        {
-            let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, None, &mut entries, served_model_name.clone())
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;
@@ -167,8 +171,8 @@ pub(crate) async fn batching_task(
                let batch_max_tokens = batch.max_tokens;
                let current_tokens = batch.current_tokens;
                let mut batches = vec![batch];
-                metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
-                metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
+                metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64);
+                metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64);

                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

@@ -207,13 +211,13 @@ pub(crate) async fn batching_task(
                {
                    // Tracking metrics
                    if min_size.is_some() {
-                        metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+                        metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone())
                            .increment(1);
                    } else {
                        let counter = if support_chunking {
-                            metrics::counter!("tgi_batch_concat", "reason" => "chunking")
+                            metrics::counter!("tgi_batch_concat", "reason" => "chunking", "model_name" => served_model_name.clone())
                        } else {
-                            metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+                            metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone())
                        };
                        counter.increment(1);
                    }
@@ -226,7 +230,7 @@ pub(crate) async fn batching_task(
                    entries.extend(new_entries);
                    // Generate one token for both the cached batch and the new batch
                    let new_cached_batch =
-                        prefill(&mut client, new_batch, cached_batch, &mut entries)
+                        prefill(&mut client, new_batch, cached_batch, &mut entries, served_model_name.clone())
                            .instrument(span)
                            .await;
                    if new_cached_batch.is_none() {
@@ -250,7 +254,7 @@ pub(crate) async fn batching_task(

                    // Generate one token for this new batch to have the attention past in cache
                    let new_cached_batch =
-                        prefill(&mut client, new_batch, None, &mut new_entries)
+                        prefill(&mut client, new_batch, None, &mut new_entries, served_model_name.clone())
                            .instrument(span)
                            .await;
                    if new_cached_batch.is_some() {
@@ -282,13 +286,13 @@ pub(crate) async fn batching_task(
                    entry.temp_span = Some(entry_batch_span);
                });

-                cached_batch = decode(&mut client, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone())
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
            }
-            metrics::gauge!("tgi_batch_current_size").set(0.0);
-            metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
+            metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0);
+            metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0);
        }
    }
 }
@@ -299,40 +303,41 @@ async fn prefill(
     batch: Batch,
     cached_batch: Option<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
+    served_model_name: String,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
-    metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
+    metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);

     match client.prefill(batch, cached_batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
-            filter_send_generations(generations, entries);
+            filter_send_generations(generations, entries, served_model_name.clone());

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

             if let Some(concat_duration) = timings.concat {
-                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
                    .record(concat_duration.as_secs_f64());
             }
-            metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill", "model_name" => served_model_name.clone())
                .record(timings.forward.as_secs_f64());
-            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone())
                .record(timings.decode.as_secs_f64());
-            metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone())
                .record(start_filtering_time.elapsed().as_secs_f64());
-            metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill", "model_name" => served_model_name.clone())
                .record(start_time.elapsed().as_secs_f64());
-            metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
+            metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             let _ = client.clear_cache(Some(batch_id)).await;
-            send_errors(err, entries);
-            metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
+            send_errors(err, entries, served_model_name.clone());
+            metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
             None
         }
     }
@@ -343,33 +348,34 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
+    served_model_name: String,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-    metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
+    metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);

     match client.decode(batches).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
-            filter_send_generations(generations, entries);
+            filter_send_generations(generations, entries, served_model_name.clone());

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

             if let Some(concat_duration) = timings.concat {
-                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
                    .record(concat_duration.as_secs_f64());
             }
-            metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(timings.forward.as_secs_f64());
-            metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(timings.decode.as_secs_f64());
-            metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(start_filtering_time.elapsed().as_secs_f64());
-            metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+            metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone())
                .record(start_time.elapsed().as_secs_f64());
-            metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
+            metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
             next_batch
         }
         // If we have an error, we discard the whole batch
@@ -377,8 +383,8 @@ async fn decode(
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
-            send_errors(err, entries);
-            metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
+            send_errors(err, entries, served_model_name.clone());
+            metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
             None
         }
     }
@@ -420,7 +426,7 @@ async fn filter_batch(
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
 /// and filter entries
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
     generations.into_iter().for_each(|generation| {
         let id = generation.request_id;
         // Get entry
@@ -434,9 +440,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
         // Send generation responses back to the infer task
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
-        let stopped = send_responses(generation, entry).inspect_err(|_err| {
+        let stopped = send_responses(generation, entry, served_model_name.clone()).inspect_err(|_err| {
            tracing::error!("Entry response channel error.");
-            metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+            metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
        }).unwrap_or(true);
        if stopped {
            entries.remove(&id).expect("ID not found in entries. This is a bug.");
@@ -448,10 +454,11 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 fn send_responses(
     generation: Generation,
     entry: &Entry,
+    served_model_name: String,
 ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
     if entry.response_tx.is_closed() {
-        metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+        metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
         return Ok(true);
     }

@@ -477,7 +484,7 @@ fn send_responses(
     // Create last Token
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
-    metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
+    metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64);
     let mut iterator = tokens_
         .ids
         .into_iter()
@@ -537,12 +544,12 @@ fn send_responses(

 /// Send errors to Infer for all `entries`
 #[instrument(skip_all)]
-fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
+fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
     entries.drain().for_each(|(_, entry)| {
         // Create and enter a span to link this function back to the entry
         let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
         let err = InferError::GenerationError(error.to_string());
-        metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
+        metrics::counter!("tgi_request_failure", "err" => "generation", "model_name" => served_model_name.clone()).increment(1);
         tracing::error!("{err}");

         // unwrap_or is valid here as we don't care if the receiver is gone.

@@ -54,6 +54,7 @@ pub async fn connect_backend(
     max_batch_total_tokens: Option<u32>,
     max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
+    served_model_name: String,
 ) -> Result<(BackendV3, BackendInfo), V3Error> {
     // Helper function
     let check_max_batch_total_tokens = |(
@@ -161,6 +162,7 @@ pub async fn connect_backend(
         max_waiting_tokens,
         max_batch_size,
         shard_info,
+        served_model_name,
     );

     tracing::info!("Using backend V3");

@@ -7,6 +7,9 @@ use thiserror::Error;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
+    #[clap(long, env)]
+    served_model_name: String,
+
     #[command(subcommand)]
     command: Option<Commands>,

@@ -74,6 +77,7 @@ struct Args {
     payload_limit: usize,
 }

+
 #[derive(Debug, Subcommand)]
 enum Commands {
     PrintSchema,
@@ -83,8 +87,11 @@ enum Commands {
 async fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
+    let _served_model_name = args.served_model_name.clone();
+
     // Pattern match configuration
     let Args {
+        served_model_name,
         command,
         max_concurrent_requests,
         max_best_of,
@@ -151,6 +158,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
+        served_model_name.clone(),
     )
     .await?;

@@ -214,6 +222,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        served_model_name.clone(),
     )
     .await?;
     Ok(())

@@ -51,6 +51,7 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        served_model_name: String,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -65,6 +66,7 @@ impl Queue {
             max_batch_total_tokens,
             support_chunking,
             queue_receiver,
+            served_model_name,
         ));

         Self { queue_sender }
@@ -124,6 +126,7 @@ async fn queue_task(
     max_batch_total_tokens: u32,
     support_chunking: bool,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    served_model_name: String,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -139,7 +142,7 @@ async fn queue_task(
         match cmd {
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
-                metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).increment(1.0);
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -150,11 +153,11 @@ async fn queue_task(
                 span,
             } => {
                 let next_batch = state
-                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget, served_model_name.clone())
                    .instrument(span)
                    .await;
                 response_sender.send(next_batch).unwrap();
-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                metrics::gauge!("tgi_queue_size", "model_name" => served_model_name.clone()).set(state.entries.len() as f64);
             }
         }
     }
@@ -235,6 +238,7 @@ impl State {
         max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
+        served_model_name: String,
     ) -> Option<NextBatch> {
         if self.entries.is_empty() {
             tracing::debug!("No queue");
@@ -274,7 +278,7 @@ impl State {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
-                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+                metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
                 tracing::debug!("Dropping entry");
                 continue;
             }
@@ -478,7 +482,7 @@ impl State {
         // Increment batch id
         self.next_batch_id += 1;

-        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
+        metrics::histogram!("tgi_batch_next_size", "model_name" => served_model_name.clone()).record(batch.size as f64);

         Some((batch_entries, batch, next_batch_span))
     }
@@ -606,21 +610,24 @@ mod tests {

     #[tokio::test]
     async fn test_next_batch_empty() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, false, None, 0, 16, false);

-        assert!(state.next_batch(None, None, 1, 1).await.is_none());
-        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
+        assert!(state.next_batch(None, None, 1, 1, served_model_name.clone()).await.is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1, served_model_name.clone()).await.is_none());
     }

     #[tokio::test]
     async fn test_next_batch_min_size() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
+
         let mut state = State::new(false, 1, false, None, 0, 16, false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -636,7 +643,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2, served_model_name.clone()).await.is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -646,13 +653,14 @@ mod tests {

     #[tokio::test]
     async fn test_next_batch_max_size() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, false, None, 0, 16, false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -666,13 +674,14 @@ mod tests {

     #[tokio::test]
     async fn test_next_batch_token_budget() {
+        let served_model_name = "bigscience/blomm-560m".to_string();
         let mut state = State::new(false, 1, false, None, 0, 16, false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -685,7 +694,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3, served_model_name.clone()).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -699,14 +708,16 @@ mod tests {

     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());

         assert!(queue.next_batch(None, None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -714,7 +725,8 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -747,7 +759,8 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -763,7 +776,8 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -788,7 +802,8 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(true, 1, false, None, 2, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(true, 1, false, None, 2, 16, false, served_model_name.clone());
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -807,7 +822,8 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, false, None, 0, 16, false);
+        let served_model_name = "bigscience/blomm-560m".to_string();
+        let queue = Queue::new(false, 1, false, None, 0, 16, false, served_model_name.clone());
         let (entry, _) = default_entry();
         queue.append(entry);

@@ -560,6 +560,10 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,

+    /// Name under which the model is served. Defaults to `model_id` if not provided.
+    #[clap(long, env)]
+    served_model_name: Option<String>,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -1802,7 +1806,7 @@ fn spawn_webserver(
         "--master-shard-uds-path".to_string(),
         format!("{}-0", args.shard_uds_path),
         "--tokenizer-name".to_string(),
-        args.model_id,
+        args.model_id.clone(),
         "--payload-limit".to_string(),
         args.payload_limit.to_string(),
     ];
@@ -1973,6 +1977,12 @@ fn main() -> Result<(), LauncherError> {
     // Pattern match configuration
     let args: Args = Args::parse();

+    let served_model_name = args.served_model_name
+        .clone()
+        .unwrap_or_else(|| args.model_id.clone());
+
+    env::set_var("SERVED_MODEL_NAME", &served_model_name);
+
     // Filter events with LOG_LEVEL
     let varname = "LOG_LEVEL";
     let env_filter = if let Ok(log_level) = std::env::var(varname) {

@@ -100,6 +100,7 @@ impl Infer {
     pub(crate) async fn generate_stream<'a>(
         &'a self,
         request: GenerateRequest,
+        served_model_name: String,
     ) -> Result<
         (
             OwnedSemaphorePermit,
@@ -121,7 +122,7 @@ impl Infer {

         // Validate request
         let mut local_request = request.clone();
-        let valid_request = self.validation.validate(request).await.map_err(|err| {
+        let valid_request = self.validation.validate(request, served_model_name.clone()).await.map_err(|err| {
             metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
             tracing::error!("{err}");
             err
@@ -165,7 +166,7 @@ impl Infer {
             local_request.inputs.push_str(&generated_text.text);
             all_generated_text = all_generated_text.or(Some(generated_text));

-            let valid_request = match self.validation.validate(local_request.clone()).await {
+            let valid_request = match self.validation.validate(local_request.clone(), served_model_name.clone()).await {
                 Ok(valid_request) => valid_request,
                 Err(err) => {
                     tracing::debug!("Failed to continue request: {err}");
@@ -245,11 +246,12 @@ impl Infer {
     pub(crate) async fn generate(
         &self,
         request: GenerateRequest,
+        served_model_name: String,
     ) -> Result<InferResponse, InferError> {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, _input_length, stream) = self.generate_stream(request).await?;
+        let (_permit, _input_length, stream) = self.generate_stream(request, served_model_name).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -322,13 +324,14 @@ impl Infer {
         &self,
         request: GenerateRequest,
         best_of: usize,
+        served_model_name: String,
     ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
         // validate best_of parameter separately
         let best_of = self.validation.validate_best_of(best_of)?;

         // create multiple generate requests
         let mut infer_responses: Vec<InferResponse> =
-            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
+            try_join_all((0..best_of).map(|_| self.generate(request.clone(), served_model_name.clone()))).await?;

         // get the sequence with the highest log probability per token
         let mut max_index = 0;

@@ -68,15 +68,16 @@ pub(crate) async fn sagemaker_compatibility(
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
     info: Extension<Info>,
+    served_model_name: Extension<String>,
     Json(req): Json<SagemakerRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     match req {
         SagemakerRequest::Generate(req) => {
-            compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
+            compat_generate(default_return_full_text, infer, compute_type, served_model_name, Json(req)).await
         }
-        SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
+        SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, served_model_name, Json(req)).await,
         SagemakerRequest::Completion(req) => {
-            completions(infer, compute_type, info, Json(req)).await
+            completions(infer, compute_type, info, served_model_name, Json(req)).await
         }
     }
 }

@ -129,6 +129,7 @@ pub(crate) async fn compat_generate(
|
||||
Extension(default_return_full_text): Extension<bool>,
|
||||
infer: Extension<Infer>,
|
||||
compute_type: Extension<ComputeType>,
|
||||
served_model_name: Extension<String>,
|
||||
Json(mut req): Json<CompatGenerateRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
// default return_full_text given the pipeline_tag
|
||||
@ -138,11 +139,11 @@ pub(crate) async fn compat_generate(
|
||||
|
||||
// switch on stream
|
||||
if req.stream {
|
||||
Ok(generate_stream(infer, compute_type, Json(req.into()))
|
||||
Ok(generate_stream(infer, served_model_name.clone(), compute_type, Json(req.into()))
|
||||
.await
|
||||
.into_response())
|
||||
} else {
|
||||
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
|
||||
let (headers, Json(generation)) = generate(infer, served_model_name.clone(), compute_type, Json(req.into())).await?;
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
Ok((headers, Json(vec![generation])).into_response())
|
||||
}
|
||||
@ -196,9 +197,10 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
|
||||
)]
|
||||
async fn get_chat_tokenize(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(served_model_name): Extension<String>,
|
||||
Json(chat): Json<ChatRequest>,
|
||||
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
metrics::counter!("tgi_request_count", "model_name" => served_model_name).increment(1);
|
||||
|
||||
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
|
||||
let input = generate_request.inputs.clone();
|
||||
@ -270,23 +272,25 @@ seed,
)]
async fn generate(
infer: Extension<Infer>,
served_model_name: Extension<String>,
Extension(ComputeType(compute_type)): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
let (headers, _, response) =
generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
generate_internal(infer, served_model_name, ComputeType(compute_type), Json(req), span).await?;
Ok((headers, response))
}

pub(crate) async fn generate_internal(
infer: Extension<Infer>,
served_model_name: Extension<String>,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
span: tracing::Span,
) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.0.clone()).increment(1);

// Do not long ultra long inputs, like image payloads.
tracing::debug!(
@ -305,10 +309,10 @@ pub(crate) async fn generate_internal(
// Inference
let (response, best_of_responses) = match req.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
let (response, best_of_responses) = infer.generate_best_of(req, best_of, served_model_name.0.clone()).await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req).await?, None),
_ => (infer.generate(req, served_model_name.0.clone()).await?, None),
};

// Token details
@ -405,14 +409,14 @@ pub(crate) async fn generate_internal(
);

// Metrics
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration")
metrics::counter!("tgi_request_success", "model_name" => served_model_name.0.clone()).increment(1);
metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.0.clone()).record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.0.clone()).record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.0.clone()).record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.0.clone()).record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.0.clone())
.record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens")
metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.0.clone())
.record(response.generated_text.generated_tokens as f64);

// Send response
@ -468,6 +472,7 @@ seed,
)]
async fn generate_stream(
Extension(infer): Extension<Infer>,
Extension(served_model_name): Extension<String>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> (
@ -476,7 +481,7 @@ async fn generate_stream(
) {
let span = tracing::Span::current();
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(req), span).await;
generate_stream_internal(infer, served_model_name, compute_type, Json(req), span).await;

let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
@ -495,6 +500,7 @@ async fn generate_stream(

async fn generate_stream_internal(
infer: Infer,
served_model_name: String,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
span: tracing::Span,
@ -503,7 +509,7 @@ async fn generate_stream_internal(
impl Stream<Item = Result<StreamResponse, InferError>>,
) {
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);

tracing::debug!("Input: {}", req.inputs);

@ -540,7 +546,7 @@ async fn generate_stream_internal(
tracing::error!("{err}");
yield Err(err);
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
match infer.generate_stream(req, served_model_name.clone()).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, input_length, response_stream)) => {
let mut index = 0;
@ -605,13 +611,13 @@ async fn generate_stream_internal(
span.record("seed", format!("{:?}", generated_text.seed));

// Metrics
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);
metrics::counter!("tgi_request_success", "model_name" => served_model_name.clone()).increment(1);
metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.clone()).record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", "model_name" => served_model_name.clone()).record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", "model_name" => served_model_name.clone()).record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", "model_name" => served_model_name.clone()).record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", "model_name" => served_model_name.clone()).record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.clone()).record(generated_text.generated_tokens as f64);

// StreamResponse
end_reached = true;
@ -704,10 +710,11 @@ pub(crate) async fn completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Extension(served_model_name): Extension<String>,
Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);

let CompletionRequest {
model,
@ -798,6 +805,7 @@ pub(crate) async fn completions(
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
let served_model_name_clone = served_model_name.clone();

// Create a future for each generate_stream_internal call.
let generate_future = async move {
@ -807,6 +815,7 @@ pub(crate) async fn completions(
tokio::spawn(async move {
let (headers, response_stream) = generate_stream_internal(
infer_clone.clone(),
served_model_name_clone.clone(),
compute_type_clone.clone(),
Json(generate_request),
span_clone.clone(),
@ -975,11 +984,13 @@ pub(crate) async fn completions(
let responses = FuturesUnordered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() {
let infer_clone = infer.clone();
let served_model_name_clone = served_model_name.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
let response_future = async move {
let result = generate_internal(
Extension(infer_clone),
Extension(served_model_name_clone),
compute_type_clone,
Json(generate_request),
span_clone,
@ -1230,10 +1241,11 @@ pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Extension(served_model_name): Extension<String>,
Json(chat): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);
let ChatRequest {
model,
stream,
@ -1255,7 +1267,7 @@ pub(crate) async fn chat_completions(
// switch on stream
if stream {
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
generate_stream_internal(infer, served_model_name, compute_type, Json(generate_request), span).await;

// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
@ -1389,7 +1401,7 @@ pub(crate) async fn chat_completions(
Ok((headers, sse).into_response())
} else {
let (headers, input_length, Json(generation)) =
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
generate_internal(Extension(infer), Extension(served_model_name), compute_type, Json(generate_request), span).await?;

let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -1688,6 +1700,7 @@ pub async fn run(
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
served_model_name: String,
) -> Result<(), WebServerError> {
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
@ -1963,6 +1976,7 @@ pub async fn run(
compat_return_full_text,
allow_origin,
payload_limit,
served_model_name,
)
.await;

@ -2024,6 +2038,7 @@ async fn start(
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
served_model_name: String,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
@ -2076,22 +2091,22 @@ async fn start(
duration_buckets.push(value);
}
// Input Length buckets
let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
let input_length_matcher = Matcher::Full(format!("tgi_request_input_length{{model_name=\"{}\"}}", served_model_name));
let input_length_buckets: Vec<f64> = (0..100)
.map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64)
.collect();
// Generated tokens buckets
let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
let generated_tokens_matcher = Matcher::Full(format!("tgi_request_generated_tokens{{model_name=\"{}\"}}", served_model_name));
let generated_tokens_buckets: Vec<f64> = (0..100)
.map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
.collect();
// Input Length buckets
let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens"));
let max_new_tokens_matcher = Matcher::Full(format!("tgi_request_max_new_tokens{{model_name=\"{}\"}}", served_model_name));
let max_new_tokens_buckets: Vec<f64> = (0..100)
.map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
.collect();
// Batch size buckets
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_matcher = Matcher::Full(format!("tgi_batch_next_size{{model_name=\"{}\"}}", served_model_name));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
// let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
@ -2334,7 +2349,8 @@ async fn start(
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility))
.route("/invocations", post(sagemaker_compatibility))
.route("/tokenize", post(tokenize));
.route("/tokenize", post(tokenize))
.layer(Extension(served_model_name));

if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
@ -133,6 +133,7 @@ impl Validation {
add_special_tokens: bool,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
served_model_name: String,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
// If we have a fast tokenizer
let (encoding, inputs) = self
@ -186,7 +187,7 @@ impl Validation {
let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();

metrics::histogram!("tgi_request_input_length").record(input_length as f64);
metrics::histogram!("tgi_request_input_length", "model_name" => served_model_name.clone()).record(input_length as f64);
Ok((
inputs,
Some(input_ids),
@ -201,6 +202,7 @@ impl Validation {
pub(crate) async fn validate(
&self,
request: GenerateRequest,
served_model_name: String,
) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters {
best_of,
@ -332,6 +334,7 @@ impl Validation {
request.add_special_tokens,
truncate,
max_new_tokens,
served_model_name.clone(),
)
.await?;

@ -405,7 +408,7 @@ impl Validation {
ignore_eos_token: false,
};

metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);
metrics::histogram!("tgi_request_max_new_tokens", "model_name" => served_model_name.clone()).record(max_new_tokens as f64);

Ok(ValidGenerateRequest {
inputs,
@ -953,10 +956,10 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);

let served_model_name = "bigscience/blomm-560m".to_string();
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name)
.await
{
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -989,9 +992,10 @@ mod tests {
disable_grammar_support,
);

let served_model_name = "bigscience/blomm-560m".to_string();
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens), served_model_name)
.await
{
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -1022,6 +1026,7 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);
let served_model_name = "bigscience/blomm-560m".to_string();
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
@ -1031,7 +1036,7 @@ mod tests {
do_sample: false,
..default_parameters()
},
})
}, served_model_name)
.await
{
Err(ValidationError::BestOfSampling) => (),
@ -1062,6 +1067,7 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);
let served_model_name = "bigscience/blomm-560m".to_string();
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
@ -1071,7 +1077,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
{
Err(ValidationError::TopP) => (),
@ -1087,7 +1093,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
{
Ok(_) => (),
@ -1103,7 +1109,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
@ -1133,6 +1139,7 @@ mod tests {
max_total_tokens,
disable_grammar_support,
);
let served_model_name = "bigscience/blomm-560m".to_string();
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
@ -1142,7 +1149,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
{
Err(ValidationError::TopNTokens(4, 5)) => (),
@ -1158,7 +1165,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();

@ -1171,7 +1178,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();

@ -1184,7 +1191,7 @@ mod tests {
max_new_tokens: Some(5),
..default_parameters()
},
})
}, served_model_name.clone())
.await
.unwrap();
@ -69,11 +69,12 @@ example = json ! ({"error": "Incomplete generation"})),
)]
pub(crate) async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(served_model_name): Extension<String>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
metrics::counter!("tgi_request_count", "model_name" => served_model_name.clone()).increment(1);

// check that theres at least one instance
if req.instances.is_empty() {
@ -111,12 +112,14 @@ pub(crate) async fn vertex_compatibility(
};

let infer_clone = infer.clone();
let served_model_name_clone = served_model_name.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();

futures.push(async move {
generate_internal(
Extension(infer_clone),
Extension(served_model_name_clone),
compute_type_clone,
Json(generate_request),
span_clone,
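For orientation only, here is a rough, standalone sketch (not part of the commit) of the labeling pattern the hunks above repeat: request-level counters and histograms gain a `model_name` label taken from the `served_model_name` value threaded through the handlers. The helper name `record_request_metrics` and the literal arguments are illustrative; only the `metrics` macro calls mirror the usage in `generate_internal` and `generate_stream_internal`.

// Standalone sketch, assuming the `metrics` crate already used by the router.
fn record_request_metrics(served_model_name: &str, total_time_secs: f64, generated_tokens: u64) {
    // Counter with a `model_name` label, matching the pattern in the diff above.
    metrics::counter!("tgi_request_success", "model_name" => served_model_name.to_string())
        .increment(1);
    // Histograms carry the same label so per-model latency and output size
    // can be queried from the exported metrics.
    metrics::histogram!("tgi_request_duration", "model_name" => served_model_name.to_string())
        .record(total_time_secs);
    metrics::histogram!("tgi_request_generated_tokens", "model_name" => served_model_name.to_string())
        .record(generated_tokens as f64);
}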