diff --git a/README.md b/README.md index 25dbbd43..fb475b09 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \ You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses. ```bash -curl localhost:3000/v1/chat/completions \ +curl localhost:8080/v1/chat/completions \ -X POST \ -d '{ "model": "tgi", diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index e0ba46c7..35a14e9e 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use std::path::PathBuf; use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::TensorRtLlmBackend; -use text_generation_router::server; +use text_generation_router::{server, usage_stats}; use tokenizers::{FromPretrainedParameters, Tokenizer}; /// App Configuration @@ -48,14 +48,14 @@ struct Args { otlp_service_name: String, #[clap(long, env)] cors_allow_origin: Option>, - #[clap(long, env, default_value_t = false)] - messages_api_enabled: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, #[clap(long, env)] auth_token: Option, #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")] executor_worker: PathBuf, + #[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, } #[tokio::main] @@ -83,10 +83,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { otlp_endpoint, otlp_service_name, cors_allow_origin, - messages_api_enabled, max_client_batch_size, auth_token, executor_worker, + usage_stats, } = args; // Launch Tokio runtime @@ -155,11 +155,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { false, None, None, - messages_api_enabled, true, max_client_batch_size, - false, - false, + usage_stats, ) .await?; Ok(()) diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index f53d898e..bc00666c 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -63,8 +63,6 @@ struct Args { #[clap(long, env)] ngrok_edge: Option, #[clap(long, env, default_value_t = false)] - messages_api_enabled: bool, - #[clap(long, env, default_value_t = false)] disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, @@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, - messages_api_enabled, disable_grammar_support, max_client_batch_size, usage_stats, @@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, - messages_api_enabled, disable_grammar_support, max_client_batch_size, usage_stats, diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index b4751bd5..769168c0 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -63,8 +63,6 @@ struct Args { #[clap(long, env)] ngrok_edge: Option, #[clap(long, env, default_value_t = false)] - messages_api_enabled: bool, - #[clap(long, env, default_value_t = false)] disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, @@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, - messages_api_enabled, disable_grammar_support, max_client_batch_size, usage_stats, @@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, - messages_api_enabled, disable_grammar_support, 
max_client_batch_size, usage_stats, diff --git a/docs/openapi.json b/docs/openapi.json index d1b60f4d..e7da2d40 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -316,6 +316,98 @@ } } }, + "/invocations": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate tokens from Sagemaker request", + "operationId": "sagemaker_compatibility", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SagemakerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Chat Completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SagemakerResponse" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/SagemakerStreamResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error", + "error_type": "validation" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation", + "error_type": "generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded", + "error_type": "overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation", + "error_type": "incomplete_generation" + } + } + } + } + } + } + }, "/metrics": { "get": { "tags": [ @@ -1865,6 +1957,45 @@ "type": "string" } }, + "SagemakerRequest": { + "oneOf": [ + { + "$ref": "#/components/schemas/CompatGenerateRequest" + }, + { + "$ref": "#/components/schemas/ChatRequest" + }, + { + "$ref": "#/components/schemas/CompletionRequest" + } + ] + }, + "SagemakerResponse": { + "oneOf": [ + { + "$ref": "#/components/schemas/GenerateResponse" + }, + { + "$ref": "#/components/schemas/ChatCompletion" + }, + { + "$ref": "#/components/schemas/CompletionFinal" + } + ] + }, + "SagemakerStreamResponse": { + "oneOf": [ + { + "$ref": "#/components/schemas/StreamResponse" + }, + { + "$ref": "#/components/schemas/ChatCompletionChunk" + }, + { + "$ref": "#/components/schemas/Chunk" + } + ] + }, "SimpleToken": { "type": "object", "required": [ diff --git a/docs/source/reference/api_reference.md b/docs/source/reference/api_reference.md index 52043c80..45d951bb 100644 --- a/docs/source/reference/api_reference.md +++ b/docs/source/reference/api_reference.md @@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene ## Amazon SageMaker -To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`. - -This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API. 
+Amazon Sagemaker natively supports the message API: ```python import json @@ -161,12 +159,11 @@ except ValueError: hub = { 'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta', 'SM_NUM_GPUS': json.dumps(1), - 'MESSAGES_API_ENABLED': True } # create Hugging Face Model Class huggingface_model = HuggingFaceModel( - image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"), + image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"), env=hub, role=role, ) diff --git a/docs/source/usage_statistics.md b/docs/source/usage_statistics.md index a2c406ec..d3878b53 100644 --- a/docs/source/usage_statistics.md +++ b/docs/source/usage_statistics.md @@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected: "max_top_n_tokens": 5, "max_total_tokens": 2048, "max_waiting_tokens": 20, - "messages_api_enabled": false, "model_config": { "model_type": "Bloom" }, diff --git a/router/src/lib.rs b/router/src/lib.rs index fdbd931e..7c40c7e3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -8,6 +8,7 @@ pub mod validation; mod kserve; pub mod logging; +mod sagemaker; pub mod usage_stats; mod vertex; diff --git a/router/src/main.rs.back b/router/src/main.rs.back deleted file mode 100644 index 36879aa4..00000000 --- a/router/src/main.rs.back +++ /dev/null @@ -1,748 +0,0 @@ -use axum::http::HeaderValue; -use clap::Parser; -use clap::Subcommand; -use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; -use hf_hub::{Cache, Repo, RepoType}; -use opentelemetry::sdk::propagation::TraceContextPropagator; -use opentelemetry::sdk::trace; -use opentelemetry::sdk::trace::Sampler; -use opentelemetry::sdk::Resource; -use opentelemetry::{global, KeyValue}; -use opentelemetry_otlp::WithExportConfig; -use std::fs::File; -use std::io::BufReader; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::path::{Path, PathBuf}; -use text_generation_router::config::Config; -use text_generation_router::usage_stats; -use text_generation_router::{ - server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, -}; -use thiserror::Error; -use tokenizers::{processors::template::TemplateProcessing, Tokenizer}; -use tower_http::cors::AllowOrigin; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; - -/// App Configuration -#[derive(Parser, Debug)] -#[clap(author, version, about, long_about = None)] -struct Args { - #[command(subcommand)] - command: Option, - - #[clap(default_value = "128", long, env)] - max_concurrent_requests: usize, - #[clap(default_value = "2", long, env)] - max_best_of: usize, - #[clap(default_value = "4", long, env)] - max_stop_sequences: usize, - #[clap(default_value = "5", long, env)] - max_top_n_tokens: u32, - #[clap(default_value = "1024", long, env)] - max_input_tokens: usize, - #[clap(default_value = "2048", long, env)] - max_total_tokens: usize, - #[clap(default_value = "1.2", long, env)] - waiting_served_ratio: f32, - #[clap(default_value = "4096", long, env)] - max_batch_prefill_tokens: u32, - #[clap(long, env)] - max_batch_total_tokens: Option, - #[clap(default_value = "20", long, env)] - max_waiting_tokens: usize, - #[clap(long, env)] - max_batch_size: Option, - #[clap(default_value = "0.0.0.0", long, env)] - hostname: String, - #[clap(default_value = "3000", long, short, env)] - port: u16, - #[clap(default_value = "/tmp/text-generation-server-0", long, env)] - master_shard_uds_path: String, - #[clap(default_value = "bigscience/bloom", 
long, env)] - tokenizer_name: String, - #[clap(long, env)] - tokenizer_config_path: Option, - #[clap(long, env)] - revision: Option, - #[clap(default_value = "2", long, env)] - validation_workers: usize, - #[clap(long, env)] - json_output: bool, - #[clap(long, env)] - otlp_endpoint: Option, - #[clap(default_value = "text-generation-inference.router", long, env)] - otlp_service_name: String, - #[clap(long, env)] - cors_allow_origin: Option>, - #[clap(long, env)] - api_key: Option, - #[clap(long, env)] - ngrok: bool, - #[clap(long, env)] - ngrok_authtoken: Option, - #[clap(long, env)] - ngrok_edge: Option, - #[clap(long, env, default_value_t = false)] - messages_api_enabled: bool, - #[clap(long, env, default_value_t = false)] - disable_grammar_support: bool, - #[clap(default_value = "4", long, env)] - max_client_batch_size: usize, - #[clap(long, env, default_value_t)] - disable_usage_stats: bool, - #[clap(long, env, default_value_t)] - disable_crash_reports: bool, -} - -#[derive(Debug, Subcommand)] -enum Commands { - PrintSchema, -} - -#[tokio::main] -async fn main() -> Result<(), RouterError> { - let args = Args::parse(); - - // Pattern match configuration - let Args { - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - hostname, - port, - master_shard_uds_path, - tokenizer_name, - tokenizer_config_path, - revision, - validation_workers, - json_output, - otlp_endpoint, - otlp_service_name, - cors_allow_origin, - api_key, - ngrok, - ngrok_authtoken, - ngrok_edge, - messages_api_enabled, - disable_grammar_support, - max_client_batch_size, - disable_usage_stats, - disable_crash_reports, - command, - } = args; - - let print_schema_command = match command { - Some(Commands::PrintSchema) => true, - None => { - // only init logging if we are not running the print schema command - init_logging(otlp_endpoint, otlp_service_name, json_output); - false - } - }; - - // Validate args - if max_input_tokens >= max_total_tokens { - return Err(RouterError::ArgumentValidation( - "`max_input_tokens` must be < `max_total_tokens`".to_string(), - )); - } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); - } - - if validation_workers == 0 { - return Err(RouterError::ArgumentValidation( - "`validation_workers` must be > 0".to_string(), - )); - } - - if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } - } - - // CORS allowed origins - // map to go inside the option and then map to parse from String to HeaderValue - // Finally, convert to AllowOrigin - let cors_allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { - AllowOrigin::list( - cors_allow_origin - .iter() - .map(|origin| origin.parse::().unwrap()), - ) - }); - - // Parse Huggingface hub token - let authorization_token = std::env::var("HF_TOKEN") - .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) - .ok(); - - // Tokenizer instance - // This will only be used to validate payloads - let local_path = Path::new(&tokenizer_name); - - // Shared API builder initialization - let api_builder = || { - let mut builder = ApiBuilder::new() - .with_progress(false) - .with_token(authorization_token); - - if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { - builder = builder.with_cache_dir(cache_dir.into()); - } - - builder - }; - - // Decide if we need to use the API based on the revision and local path - let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); - - // Initialize API if needed - #[derive(Clone)] - enum Type { - Api(Api), - Cache(Cache), - None, - } - let api = if use_api { - if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { - let cache = std::env::var("HUGGINGFACE_HUB_CACHE") - .map_err(|_| ()) - .map(|cache_dir| Cache::new(cache_dir.into())) - .unwrap_or_else(|_| Cache::default()); - - tracing::warn!("Offline mode active using cache defaults"); - Type::Cache(cache) - } else { - tracing::info!("Using the Hugging Face API"); - match api_builder().build() { - Ok(api) => Type::Api(api), - Err(_) => { - tracing::warn!("Unable to build the Hugging Face API"); - Type::None - } - } - } - } else { - Type::None - }; - - // Load tokenizer and model info - let ( - tokenizer_filename, - config_filename, - tokenizer_config_filename, - preprocessor_config_filename, - processor_config_filename, - model_info, - ) = match api { - Type::None => ( - Some(local_path.join("tokenizer.json")), - Some(local_path.join("config.json")), - Some(local_path.join("tokenizer_config.json")), - Some(local_path.join("preprocessor_config.json")), - Some(local_path.join("processor_config.json")), - None, - ), - Type::Api(api) => { - let api_repo = api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.clone().unwrap_or_else(|| "main".to_string()), - )); - - let tokenizer_filename = match api_repo.get("tokenizer.json").await { - Ok(tokenizer_filename) => Some(tokenizer_filename), - Err(_) => get_base_tokenizer(&api, &api_repo).await, - }; - let config_filename = api_repo.get("config.json").await.ok(); - let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); - let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); - let processor_config_filename = api_repo.get("processor_config.json").await.ok(); - - let model_info = if let Some(model_info) = get_model_info(&api_repo).await { - Some(model_info) - } else { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - None - }; - ( - tokenizer_filename, - config_filename, - tokenizer_config_filename, - preprocessor_config_filename, - processor_config_filename, - model_info, - ) - } - Type::Cache(cache) => { - let repo = cache.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.clone().unwrap_or_else(|| "main".to_string()), - )); - ( - 
repo.get("tokenizer.json"), - repo.get("config.json"), - repo.get("tokenizer_config.json"), - repo.get("preprocessor_config.json"), - repo.get("processor_config.json"), - None, - ) - } - }; - let config: Option = config_filename.and_then(|filename| { - std::fs::read_to_string(filename) - .ok() - .as_ref() - .and_then(|c| { - let config: Result = serde_json::from_str(c); - if let Err(err) = &config { - tracing::warn!("Could not parse config {err:?}"); - } - config.ok() - }) - }); - let model_info = model_info.unwrap_or_else(|| HubModelInfo { - model_id: tokenizer_name.to_string(), - sha: None, - pipeline_tag: None, - }); - - // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. - let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path - { - HubTokenizerConfig::from_file(filename) - } else { - tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) - }; - let tokenizer_config = tokenizer_config.unwrap_or_else(|| { - tracing::warn!("Could not find tokenizer config locally and no API specified"); - HubTokenizerConfig::default() - }); - let tokenizer_class = tokenizer_config.tokenizer_class.clone(); - - let tokenizer: Option = tokenizer_filename.and_then(|filename| { - let mut tokenizer = Tokenizer::from_file(filename).ok(); - if let Some(tokenizer) = &mut tokenizer { - if let Some(class) = &tokenizer_config.tokenizer_class { - if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ - if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { - tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); - tokenizer.with_post_processor(post_processor); - } - } - } - } - tokenizer - }); - - let preprocessor_config = - preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); - let processor_config = processor_config_filename - .and_then(HubProcessorConfig::from_file) - .unwrap_or_default(); - - tracing::info!("Using config {config:?}"); - if tokenizer.is_none() { - tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); - tracing::warn!("Rust input length validation and truncation is disabled"); - } - - // if pipeline-tag == text-generation we default to return_full_text = true - let compat_return_full_text = match &model_info.pipeline_tag { - None => { - tracing::warn!("no pipeline tag found for model {tokenizer_name}"); - true - } - Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", - }; - - // Determine the server port based on the feature and environment variable. 
- let port = if cfg!(feature = "google") { - std::env::var("AIP_HTTP_PORT") - .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) - .unwrap_or(port) - } else { - port - }; - - let addr = match hostname.parse() { - Ok(ip) => SocketAddr::new(ip, port), - Err(_) => { - tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) - } - }; - - // Only send usage stats when TGI is run in container and the function returns Some - let is_container = matches!(usage_stats::is_container(), Ok(true)); - - let user_agent = if !disable_usage_stats && is_container { - let reduced_args = usage_stats::Args::new( - config.clone(), - tokenizer_class, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - revision, - validation_workers, - messages_api_enabled, - disable_grammar_support, - max_client_batch_size, - disable_usage_stats, - disable_crash_reports, - ); - Some(usage_stats::UserAgent::new(reduced_args)) - } else { - None - }; - - if let Some(ref ua) = user_agent { - let start_event = - usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); - tokio::spawn(async move { - start_event.send().await; - }); - }; - - // Run server - let result = server::run( - master_shard_uds_path, - model_info, - compat_return_full_text, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - tokenizer, - config, - validation_workers, - addr, - cors_allow_origin, - api_key, - ngrok, - ngrok_authtoken, - ngrok_edge, - tokenizer_config, - preprocessor_config, - processor_config, - messages_api_enabled, - disable_grammar_support, - max_client_batch_size, - print_schema_command, - ) - .await; - - match result { - Ok(_) => { - if let Some(ref ua) = user_agent { - let stop_event = usage_stats::UsageStatsEvent::new( - ua.clone(), - usage_stats::EventType::Stop, - None, - ); - stop_event.send().await; - }; - Ok(()) - } - Err(e) => { - if let Some(ref ua) = user_agent { - if !disable_crash_reports { - let error_event = usage_stats::UsageStatsEvent::new( - ua.clone(), - usage_stats::EventType::Error, - Some(e.to_string()), - ); - error_event.send().await; - } else { - let unknow_error_event = usage_stats::UsageStatsEvent::new( - ua.clone(), - usage_stats::EventType::Error, - Some("unknow_error".to_string()), - ); - unknow_error_event.send().await; - } - }; - Err(RouterError::WebServer(e)) - } - } -} - -/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: -/// - otlp_endpoint is an optional URL to an Open Telemetry collector -/// - otlp_service_name service name to appear in APM -/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) -/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) -/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) -fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { - let mut layers = Vec::new(); - - // STDOUT/STDERR layer - let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); - let fmt_layer = tracing_subscriber::fmt::layer() - .with_file(true) - .with_ansi(ansi) - .with_line_number(true); - - let fmt_layer = match 
json_output { - true => fmt_layer.json().flatten_event(true).boxed(), - false => fmt_layer.boxed(), - }; - layers.push(fmt_layer); - - // OpenTelemetry tracing layer - if let Some(otlp_endpoint) = otlp_endpoint { - global::set_text_map_propagator(TraceContextPropagator::new()); - - let tracer = opentelemetry_otlp::new_pipeline() - .tracing() - .with_exporter( - opentelemetry_otlp::new_exporter() - .tonic() - .with_endpoint(otlp_endpoint), - ) - .with_trace_config( - trace::config() - .with_resource(Resource::new(vec![KeyValue::new( - "service.name", - otlp_service_name, - )])) - .with_sampler(Sampler::AlwaysOn), - ) - .install_batch(opentelemetry::runtime::Tokio); - - if let Ok(tracer) = tracer { - layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); - init_tracing_opentelemetry::init_propagator().unwrap(); - }; - } - - // Filter events with LOG_LEVEL - let varname = "LOG_LEVEL"; - let env_filter = if let Ok(log_level) = std::env::var(varname) { - // Override to avoid simple logs to be spammed with tokio level informations - let log_level = match &log_level[..] { - "warn" => "text_generation_launcher=warn,text_generation_router=warn", - "info" => "text_generation_launcher=info,text_generation_router=info", - "debug" => "text_generation_launcher=debug,text_generation_router=debug", - log_level => log_level, - }; - EnvFilter::builder() - .with_default_directive(LevelFilter::INFO.into()) - .parse_lossy(log_level) - } else { - EnvFilter::new("info") - }; - - tracing_subscriber::registry() - .with(env_filter) - .with(layers) - .init(); -} - -/// get model info from the Huggingface Hub -pub async fn get_model_info(api: &ApiRepo) -> Option { - let response = api.info_request().send().await.ok()?; - - if response.status().is_success() { - let hub_model_info: HubModelInfo = - serde_json::from_str(&response.text().await.ok()?).ok()?; - if let Some(sha) = &hub_model_info.sha { - tracing::info!( - "Serving revision {sha} of model {}", - hub_model_info.model_id - ); - } - Some(hub_model_info) - } else { - None - } -} - -/// get base tokenizer -pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { - let config_filename = api_repo.get("config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of `User`. - let config: serde_json::Value = serde_json::from_reader(reader).ok()?; - - if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { - let api_base_repo = api.repo(Repo::with_revision( - base_model_id.to_string(), - RepoType::Model, - "main".to_string(), - )); - - api_base_repo.get("tokenizer.json").await.ok() - } else { - None - } -} - -/// get tokenizer_config from the Huggingface Hub -pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { - let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(tokenizer_config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
- let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) - .map_err(|e| { - tracing::warn!("Unable to parse tokenizer config: {}", e); - e - }) - .ok()?; - - Some(tokenizer_config) -} - -/// Create a post_processor for the LlamaTokenizer -pub fn create_post_processor( - tokenizer: &Tokenizer, - tokenizer_config: &HubTokenizerConfig, -) -> Result { - let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); - let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); - - let bos_token = tokenizer_config.bos_token.as_ref(); - let eos_token = tokenizer_config.eos_token.as_ref(); - - if add_bos_token && bos_token.is_none() { - panic!("add_bos_token = true but bos_token is None"); - } - - if add_eos_token && eos_token.is_none() { - panic!("add_eos_token = true but eos_token is None"); - } - - let mut single = Vec::new(); - let mut pair = Vec::new(); - let mut special_tokens = Vec::new(); - - if add_bos_token { - if let Some(bos) = bos_token { - let bos_token_id = tokenizer - .token_to_id(bos.as_str()) - .expect("Should have found the bos token id"); - special_tokens.push((bos.as_str(), bos_token_id)); - single.push(format!("{}:0", bos.as_str())); - pair.push(format!("{}:0", bos.as_str())); - } - } - - single.push("$A:0".to_string()); - pair.push("$A:0".to_string()); - - if add_eos_token { - if let Some(eos) = eos_token { - let eos_token_id = tokenizer - .token_to_id(eos.as_str()) - .expect("Should have found the eos token id"); - special_tokens.push((eos.as_str(), eos_token_id)); - single.push(format!("{}:0", eos.as_str())); - pair.push(format!("{}:0", eos.as_str())); - } - } - - if add_bos_token { - if let Some(bos) = bos_token { - pair.push(format!("{}:1", bos.as_str())); - } - } - - pair.push("$B:1".to_string()); - - if add_eos_token { - if let Some(eos) = eos_token { - pair.push(format!("{}:1", eos.as_str())); - } - } - - let post_processor = TemplateProcessing::builder() - .try_single(single)? - .try_pair(pair)? 
- .special_tokens(special_tokens) - .build()?; - - Ok(post_processor) -} - -#[derive(Debug, Error)] -enum RouterError { - #[error("Argument validation error: {0}")] - ArgumentValidation(String), - #[error("WebServer error: {0}")] - WebServer(#[from] server::WebServerError), - #[error("Tokio runtime failed to start: {0}")] - Tokio(#[from] std::io::Error), -} - -#[cfg(test)] -mod tests { - use super::*; - use text_generation_router::TokenizerConfigToken; - - #[test] - fn test_create_post_processor() { - let tokenizer_config = HubTokenizerConfig { - add_bos_token: None, - add_eos_token: None, - bos_token: Some(TokenizerConfigToken::String("".to_string())), - eos_token: Some(TokenizerConfigToken::String("".to_string())), - chat_template: None, - tokenizer_class: None, - completion_template: None, - }; - - let tokenizer = - Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap(); - let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap(); - - let expected = TemplateProcessing::builder() - .try_single(":0 $A:0") - .unwrap() - .try_pair(":0 $A:0 :1 $B:1") - .unwrap() - .special_tokens(vec![("".to_string(), 1)]) - .build() - .unwrap(); - - assert_eq!(post_processor, expected); - } -} diff --git a/router/src/sagemaker.rs b/router/src/sagemaker.rs new file mode 100644 index 00000000..750ef222 --- /dev/null +++ b/router/src/sagemaker.rs @@ -0,0 +1,82 @@ +use crate::infer::Infer; +use crate::server::{chat_completions, compat_generate, completions, ComputeType}; +use crate::{ + ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest, + CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse, +}; +use axum::extract::Extension; +use axum::http::StatusCode; +use axum::response::Response; +use axum::Json; +use serde::{Deserialize, Serialize}; +use tracing::instrument; +use utoipa::ToSchema; + +#[derive(Clone, Deserialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum SagemakerRequest { + Generate(CompatGenerateRequest), + Chat(ChatRequest), + Completion(CompletionRequest), +} + +// Used for OpenAPI specs +#[allow(dead_code)] +#[derive(Serialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum SagemakerResponse { + Generate(GenerateResponse), + Chat(ChatCompletion), + Completion(CompletionFinal), +} + +// Used for OpenAPI specs +#[allow(dead_code)] +#[derive(Serialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum SagemakerStreamResponse { + Generate(StreamResponse), + Chat(ChatCompletionChunk), + Completion(Chunk), +} + +/// Generate tokens from Sagemaker request +#[utoipa::path( +post, +tag = "Text Generation Inference", +path = "/invocations", +request_body = SagemakerRequest, +responses( +(status = 200, description = "Generated Chat Completion", +content( +("application/json" = SagemakerResponse), +("text/event-stream" = SagemakerStreamResponse), +)), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation", "error_type": "generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error", "error_type": "validation"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! 
({"error": "Incomplete generation", "error_type": "incomplete_generation"})), +) +)] +#[instrument(skip_all)] +pub(crate) async fn sagemaker_compatibility( + default_return_full_text: Extension, + infer: Extension, + compute_type: Extension, + info: Extension, + Json(req): Json, +) -> Result)> { + match req { + SagemakerRequest::Generate(req) => { + compat_generate(default_return_full_text, infer, compute_type, Json(req)).await + } + SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await, + SagemakerRequest::Completion(req) => { + completions(infer, compute_type, info, Json(req)).await + } + } +} diff --git a/router/src/server.rs b/router/src/server.rs index 5e6e6960..5abca058 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -7,6 +7,10 @@ use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; +use crate::sagemaker::{ + sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse, + __path_sagemaker_compatibility, +}; use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; use crate::ChatTokenizeResponse; @@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})), ) )] #[instrument(skip(infer, req))] -async fn compat_generate( +pub(crate) async fn compat_generate( Extension(default_return_full_text): Extension, infer: Extension, compute_type: Extension, @@ -678,7 +682,7 @@ time_per_token, seed, ) )] -async fn completions( +pub(crate) async fn completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, @@ -1202,7 +1206,7 @@ time_per_token, seed, ) )] -async fn chat_completions( +pub(crate) async fn chat_completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, @@ -1513,11 +1517,13 @@ completions, tokenize, metrics, openai_get_model_info, +sagemaker_compatibility, ), components( schemas( Info, CompatGenerateRequest, +SagemakerRequest, GenerateRequest, GrammarType, ChatRequest, @@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob, ChatCompletion, CompletionRequest, CompletionComplete, +SagemakerResponse, +SagemakerStreamResponse, Chunk, Completion, CompletionFinal, @@ -1607,7 +1615,6 @@ pub async fn run( ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, - messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, usage_stats_level: usage_stats::UsageStatsLevel, @@ -1836,7 +1843,6 @@ pub async fn run( // max_batch_size, revision.clone(), validation_workers, - messages_api_enabled, disable_grammar_support, max_client_batch_size, usage_stats_level, @@ -1878,7 +1884,6 @@ pub async fn run( ngrok, _ngrok_authtoken, _ngrok_edge, - messages_api_enabled, disable_grammar_support, max_client_batch_size, model_info, @@ -1938,7 +1943,6 @@ async fn start( ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, - messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, model_info: HubModelInfo, @@ -2253,6 +2257,7 @@ async fn start( .route("/v1/chat/completions", post(chat_completions)) .route("/v1/completions", post(completions)) .route("/vertex", post(vertex_compatibility)) + .route("/invocations", post(sagemaker_compatibility)) .route("/tokenize", post(tokenize)); if let Some(api_key) = api_key { @@ -2288,13 +2293,6 @@ async fn start( .route("/metrics", get(metrics)) .route("/v1/models", get(openai_get_model_info)); - // 
Conditional AWS Sagemaker route - let aws_sagemaker_route = if messages_api_enabled { - Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED - } else { - Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise - }; - let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); @@ -2302,8 +2300,7 @@ async fn start( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) - .merge(info_routes) - .merge(aws_sagemaker_route); + .merge(info_routes); #[cfg(feature = "google")] { diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index 0282ac63..e9d98327 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -93,7 +93,6 @@ pub struct Args { // max_batch_size: Option, revision: Option, validation_workers: usize, - messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, usage_stats_level: UsageStatsLevel, @@ -117,7 +116,6 @@ impl Args { // max_batch_size: Option, revision: Option, validation_workers: usize, - messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, usage_stats_level: UsageStatsLevel, @@ -138,7 +136,6 @@ impl Args { // max_batch_size, revision, validation_workers, - messages_api_enabled, disable_grammar_support, max_client_batch_size, usage_stats_level, diff --git a/update_doc.py b/update_doc.py index 203aaced..6357cc00 100644 --- a/update_doc.py +++ b/update_doc.py @@ -172,6 +172,8 @@ def check_openapi(check: bool): # allow for trailing whitespace since it's not significant # and the precommit hook will remove it "lint", + "--skip-rule", + "security-defined", filename, ], capture_output=True,
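
Taken together, these changes drop the `MESSAGES_API_ENABLED` / `messages_api_enabled` toggle and register a single `/invocations` route whose handler inspects the payload shape (via the untagged `SagemakerRequest` enum) and dispatches to `compat_generate`, `chat_completions`, or `completions`. As a rough illustration — assuming a local TGI instance listening on port 8080, as in the README examples above, and a loaded model addressed as `tgi` — both request styles can now hit the same route with no extra configuration:

```bash
# Chat-style payload: deserializes as a ChatRequest and is routed to chat_completions
curl 127.0.0.1:8080/invocations \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "max_tokens": 64
    }'

# Classic generate payload: deserializes as a CompatGenerateRequest and is routed to compat_generate
curl 127.0.0.1:8080/invocations \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "What is deep learning?",
        "parameters": {"max_new_tokens": 64}
    }'
```

Because `SagemakerRequest` is `#[serde(untagged)]`, dispatch depends only on which required fields are present (`inputs`, `messages`, or a completion-style `prompt`); no header or environment variable selects the mode.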