diff --git a/Cargo.lock b/Cargo.lock
index 27c345b3..91e1b4c7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -773,9 +773,9 @@ dependencies = [
 
 [[package]]
 name = "futures-channel"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
+checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
 dependencies = [
  "futures-core",
  "futures-sink",
@@ -783,9 +783,9 @@
 
 [[package]]
 name = "futures-core"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
+checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
 
 [[package]]
 name = "futures-executor"
@@ -800,15 +800,15 @@ dependencies = [
 
 [[package]]
 name = "futures-io"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
+checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
 
 [[package]]
 name = "futures-macro"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
+checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -817,21 +817,21 @@
 
 [[package]]
 name = "futures-sink"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"
+checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
 
 [[package]]
 name = "futures-task"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
+checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
 
 [[package]]
 name = "futures-util"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
+checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
 dependencies = [
  "futures-channel",
  "futures-core",
@@ -1369,6 +1369,15 @@ dependencies = [
  "unicase",
 ]
 
+[[package]]
+name = "minijinja"
+version = "1.0.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -2803,10 +2812,12 @@ dependencies = [
  "axum-tracing-opentelemetry",
  "clap",
  "futures",
+ "futures-util",
  "hf-hub",
  "init-tracing-opentelemetry",
  "metrics",
  "metrics-exporter-prometheus",
+ "minijinja",
  "ngrok",
  "nohash-hasher",
  "opentelemetry",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 55af635a..134782ff 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -43,6 +43,8 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 hf-hub = "0.3.1"
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
+minijinja = "1.0.10"
+futures-util = "0.3.30"
 
 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
diff --git a/router/src/infer.rs b/router/src/infer.rs
index bf5920da..3ce4923c 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,7 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
+use crate::HubTokenizerConfig;
+use crate::{ChatRequest, GenerateRequest, PrefillToken};
 use crate::{Entry, Queue, Token};
-use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
 use std::sync::{
@@ -26,6 +27,8 @@ pub struct Infer {
     validation: Validation,
     /// Request queue
     queue: Queue,
+    /// Chat formatter
+    tokenizer_config: HubTokenizerConfig,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
@@ -52,6 +55,7 @@ impl Infer {
         window_size: Option<u32>,
         speculate: u32,
         generation_health: Arc<AtomicBool>,
+        tokenizer_config: HubTokenizerConfig,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size, speculate);
@@ -79,6 +83,7 @@ impl Infer {
             queue,
             shared,
             limit_concurrent_requests: semaphore,
+            tokenizer_config,
         }
     }
 
@@ -133,6 +138,28 @@ impl Infer {
         Ok((permit, UnboundedReceiverStream::new(response_rx)))
     }
 
+    /// Apply the chat template to the chat request
+    #[instrument(skip_all)]
+    pub(crate) fn apply_chat_template(
+        &self,
+        chat: ChatRequest,
+    ) -> Result<String, ChatTemplateError> {
+        let mut env = minijinja::Environment::new();
+        let chat_template = self
+            .tokenizer_config
+            .chat_template
+            .as_ref()
+            .ok_or(ChatTemplateError::TemplateNotFound)?;
+        env.add_template("_", chat_template)
+            .map_err(|e| ChatTemplateError::TemplateError(e))?;
+        let jinja_tmpl = env
+            .get_template("_")
+            .map_err(|e| ChatTemplateError::TemplateError(e))?;
+        jinja_tmpl
+            .render(chat)
+            .map_err(|e| ChatTemplateError::TemplateError(e))
+    }
+
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -666,3 +693,20 @@ impl InferError {
         }
     }
 }
+
+#[derive(Debug, Error)]
+pub enum ChatTemplateError {
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
+    #[error("Template not found")]
+    TemplateNotFound,
+}
+
+impl ChatTemplateError {
+    pub(crate) fn error_type(&self) -> &str {
+        match self {
+            ChatTemplateError::TemplateError(_) => "template_error",
+            ChatTemplateError::TemplateNotFound => "template_not_found",
+        }
+    }
+}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 898fcd04..411df519 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -20,6 +20,12 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
+#[derive(Clone, Deserialize)]
+pub struct HubTokenizerConfig {
+    #[serde(default)]
+    pub chat_template: Option<String>,
+}
+
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -165,6 +171,182 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletion {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<ChatCompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionComplete {
+    pub index: u32,
+    pub message: Message,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+impl ChatCompletion {
+    pub(crate) fn new(
+        output: String,
+        created: u64,
+        details: Details,
+        prompt_character_count: u32,
+    ) -> Self {
+        Self {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created,
+            model: "".to_string(),
+            system_fingerprint: "".to_string(),
+            choices: vec![ChatCompletionComplete {
+                index: 0,
+                message: Message {
+                    role: "assistant".to_string(),
+                    content: output,
+                },
+                logprobs: None,
+                finish_reason: None,
+            }],
+            usage: Usage {
+                prompt_tokens: prompt_character_count,
+                completion_tokens: details.generated_tokens,
+                total_tokens: prompt_character_count + details.generated_tokens,
+            },
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<ChatCompletionChoice>,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionChoice {
+    pub index: u32,
+    pub delta: ChatCompletionDelta,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionDelta {
+    pub role: String,
+    pub content: String,
+}
+
+impl ChatCompletionChunk {
+    pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
+        Self {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created,
+            model: "".to_string(),
+            system_fingerprint: "".to_string(),
+            choices: vec![ChatCompletionChoice {
+                index,
+                delta: ChatCompletionDelta {
+                    role: "assistant".to_string(),
+                    content: delta,
+                },
+                logprobs: None,
+                finish_reason: None,
+            }],
+        }
+    }
+}
+
+fn default_request_messages() -> Vec<Message> {
+    vec![Message {
+        role: "system".to_string(),
+        content: "My name is David and I".to_string(),
+    }]
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct ChatRequest {
+    /// UNUSED
+    #[schema(example = "bigscience/bloom-560m")]
+    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: String, /* NOTE: UNUSED */
+
+    /// A list of messages comprising the conversation so far.
+    #[serde(default = "default_request_messages")]
+    pub messages: Vec<Message>,
+
+    /// UNUSED
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    pub frequency_penalty: Option<f32>,
+
+    /// UNUSED
+    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+    /// result in a ban or exclusive selection of the relevant token.
+    #[serde(default)]
+    pub logit_bias: Option<Vec<f32>>,
+
+    /// UNUSED
+    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
+    /// output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview
+    /// model.
+    #[serde(default)]
+    pub logprobs: Option<bool>,
+
+    /// UNUSED
+    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
+    /// an associated log probability. logprobs must be set to true if this parameter is used.
+    #[serde(default)]
+    pub top_logprobs: Option<u32>,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    pub max_tokens: Option<u32>,
+
+    /// UNUSED
+    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
+    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
+    #[serde(default)]
+    pub n: Option<u32>,
+
+    /// UNUSED
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
+    /// increasing the model's likelihood to talk about new topics.
+    #[serde(default)]
+    pub presence_penalty: Option<f32>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct Message {
+    #[schema(example = "system")]
+    pub role: String,
+    #[schema(example = "My name is David and I")]
+    pub content: String,
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateRequest {
     #[schema(example = "My name is Olivier and I")]
diff --git a/router/src/main.rs b/router/src/main.rs
index d90632ef..875070d1 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -11,7 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
 use text_generation_client::{ClientError, ShardedClient};
-use text_generation_router::{server, HubModelInfo};
+use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
 use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
@@ -176,7 +176,7 @@ fn main() -> Result<(), RouterError> {
             sha: None,
             pipeline_tag: None,
         },
-        false => get_model_info(&tokenizer_name, revision, authorization_token)
+        false => get_model_info(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
             .await
             .unwrap_or_else(|| {
                 tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
@@ -188,6 +188,20 @@ fn main() -> Result<(), RouterError> {
             }),
     };
 
+    let tokenizer_config: HubTokenizerConfig = match local_model {
+        true => HubTokenizerConfig{
+            chat_template: None,
+        },
+        false => get_tokenizer_config(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
+            .await.unwrap_or_else(|| {
+                tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
+                HubTokenizerConfig{
+                    chat_template: None,
+                }
+            }),
+    };
+
+
     // if pipeline-tag == text-generation we default to return_full_text = true
     let compat_return_full_text = match &model_info.pipeline_tag {
         None => {
@@ -277,6 +291,7 @@ fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
+        tokenizer_config,
     )
     .await?;
     Ok(())
@@ -341,8 +356,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 /// get model info from the Huggingface Hub
 pub async fn get_model_info(
     model_id: &str,
-    revision: Option<String>,
-    token: Option<String>,
+    revision: Option<&str>,
+    token: Option<&str>,
 ) -> Option<HubModelInfo> {
     let revision = match revision {
         None => {
             tracing::warn!("`--revision` is not set");
             tracing::warn!("We strongly advise to set it to a known supported commit.");
             "main".to_string()
         }
-        Some(revision) => revision,
+        Some(revision) => revision.to_string(),
     };
 
     let client = reqwest::Client::new();
@@ -379,6 +394,41 @@
     }
 }
 
+/// get tokenizer_config from the Huggingface Hub
+pub async fn get_tokenizer_config(
+    model_id: &str,
+    revision: Option<&str>,
+    token: Option<&str>,
+) -> Option<HubTokenizerConfig> {
+    let revision = match revision {
+        None => {
+            tracing::warn!("`--revision` is not set");
+            tracing::warn!("We strongly advise to set it to a known supported commit.");
+            "main".to_string()
+        }
+        Some(revision) => revision.to_string(),
+    };
+    let client = reqwest::Client::new();
+    // Poor man's urlencode
+    let revision = revision.replace('/', "%2F");
+    let url = format!(
+        "https://huggingface.co/{}/raw/{}/tokenizer_config.json",
+        model_id, revision
+    );
+    let mut builder = client.get(url).timeout(Duration::from_secs(5));
+    if let Some(token) = token {
+        builder = builder.bearer_auth(token);
+    }
+    let response = builder.send().await.ok()?;
+    if response.status().is_success() {
+        let text = response.text().await.ok()?;
+        let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?;
+        Some(hub_tokenizer_config)
+    } else {
+        None
+    }
+}
+
 #[derive(Debug, Error)]
 enum RouterError {
     #[error("Argument validation error: {0}")]
diff --git a/router/src/server.rs b/router/src/server.rs
index fe1b8309..d4521d96 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -2,10 +2,11 @@
 use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
+use crate::HubTokenizerConfig;
 use crate::{
-    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
-    StreamDetails, StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -337,6 +338,22 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
+    let on_message_callback = |stream_token: StreamResponse| {
+        let event = Event::default();
+        event.json_data(stream_token).unwrap()
+    };
+    let (headers, response_stream) =
+        generate_stream_internal(infer, Json(req.into()), on_message_callback).await;
+    let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+    (headers, sse)
+}
+
+///
+async fn generate_stream_internal(
+    infer: Infer,
+    Json(req): Json<GenerateRequest>,
+    on_message_callback: impl Fn(StreamResponse) -> Event,
+) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
@@ -401,7 +418,8 @@
                                     details: None,
                                 };
 
-                                yield Ok(Event::default().json_data(stream_token).unwrap())
+                                let event = on_message_callback(stream_token);
+                                yield Ok(event);
                             }
                             // Yield event for last token and compute timings
                             InferStreamResponse::End {
@@ -463,7 +481,9 @@
                                         details
                                     };
 
-                                    yield Ok(Event::default().json_data(stream_token).unwrap());
+
+                                    let event = on_message_callback(stream_token);
+                                    yield Ok(event);
                                     break;
                                 }
                             }
@@ -494,7 +514,141 @@
         }
     };
 
-    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
+    (headers, stream)
+}
+
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/chat/completions",
+    request_body = ChatRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = GenerateResponse),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+    )]
+#[instrument(
+    skip_all,
+    fields(
+        // parameters = ? req.parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+    )]
+async fn chat(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<ChatRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    // extract the values we need for the chat request
+    let stream = req.stream;
+    let max_new_tokens = match req.max_tokens {
+        Some(max_new_tokens) => Some(max_new_tokens),
+        None => Some(100)
+    };
+
+    // apply chat template to flatten the request into a single input
+    let inputs = match infer.apply_chat_template(req) {
+        Ok(inputs) => inputs,
+        Err(err) => {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            return Err((
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                    error_type: err.error_type().to_string(),
+                }),
+            ));
+        }
+    };
+
+    // poor man's token count (assumes that each character is a token)
+    let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
+
+    // build the request passing some parameters
+    let generate_request = GenerateRequest {
+        inputs: inputs.to_string(),
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature: None,
+            repetition_penalty: None,
+            top_k: None,
+            top_p: None,
+            typical_p: None,
+            do_sample: false,
+            max_new_tokens,
+            return_full_text: None,
+            stop: Vec::new(),
+            truncate: None,
+            watermark: false,
+            details: true,
+            decoder_input_details: false,
+            seed: None,
+            top_n_tokens: None,
+        },
+    };
+
+    // switch on stream
+    if stream {
+        // pass this callback to the stream generation and build the required event structure
+        let on_message_callback = |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(ChatCompletionChunk::new(
+                    stream_token.token.text,
+                    current_time,
+                    0,
+                ))
+                .unwrap_or_else(|_| {
+                    println!("Failed to serialize ChatCompletionChunk");
+                    Event::default()
+                })
+        };
+
+        let (headers, response_stream) =
+            generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) =
+            generate(Extension(infer), Json(generate_request)).await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        // build the complete response object with the full text
+        let response = ChatCompletion::new(
+            generation.generated_text,
+            current_time,
+            generation.details.unwrap(),
+            prompt_character_count,
+        );
+
+        // wrap generation inside a Vec to match api-inference
+        Ok((headers, Json(response)).into_response())
+    }
 }
 
 /// Prometheus metrics scrape endpoint
@@ -532,6 +686,7 @@ pub async fn run(
     ngrok: bool,
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
+    tokenizer_config: HubTokenizerConfig,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -598,6 +753,7 @@
         shard_info.window_size,
         shard_info.speculate,
         generation_health,
+        tokenizer_config,
     );
 
     // Duration buckets
@@ -687,6 +843,7 @@
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
+        .route("/v1/chat/completions", post(chat))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
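
For reviewers who want to poke at the templating piece in isolation, here is a minimal, self-contained sketch of what `Infer::apply_chat_template` does with the Hub `chat_template` via minijinja. The template string and the messages below are illustrative assumptions, not taken from any real `tokenizer_config.json`; the router actually renders the whole serialized `ChatRequest`, so a template sees the `messages` field as a top-level variable, which is what the sketch mimics.

```rust
// Standalone sketch (not part of this diff): render a chat template with
// minijinja 1.x the same way apply_chat_template does. Template and messages
// are made up for illustration.
use minijinja::{context, Environment};

fn main() {
    // A hypothetical Hub-style chat template.
    let chat_template =
        "{% for message in messages %}<|{{ message.role }}|>\n{{ message.content }}\n{% endfor %}<|assistant|>\n";

    let mut env = Environment::new();
    env.add_template("_", chat_template).expect("template should parse");
    let tmpl = env.get_template("_").expect("template was just added");

    // The router passes the serialized ChatRequest; an equivalent `messages`
    // variable is enough for the template above.
    let rendered = tmpl
        .render(context! {
            messages => vec![
                context! { role => "system", content => "You are a helpful assistant." },
                context! { role => "user", content => "My name is David and I" },
            ],
        })
        .expect("render should succeed");

    println!("{rendered}");
}
```

The flattened string produced this way is what gets handed to the regular generate path, which is why `/v1/chat/completions` can reuse `generate_stream_internal` without touching the batching logic.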