feat: support OpenAI chat completions API

Preferring a PR from the original repo rather than a fork so CI can run: https://github.com/huggingface/text-generation-inference/pull/1408
drbh 2024-01-10 10:08:51 -05:00
parent ac08b4ef9c
commit 9fdf47f766
7 changed files with 665 additions and 213 deletions
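For orientation before the diff: this commit adds an OpenAI-compatible /v1/chat/completions route to the router. The sketch below is not part of the commit and shows one way to exercise the route once a router is running; the host, port, and model name are assumptions, and it presumes a client built with reqwest (json feature) and tokio. The request fields mirror the ChatRequest struct added in lib.rs further down.

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumed endpoint: a TGI router listening on localhost:3000 (adjust for your deployment).
    let body = json!({
        // `model` is accepted for OpenAI-client compatibility but unused by the router.
        "model": "tgi",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "My name is David and I"}
        ],
        "max_tokens": 32,
        "stream": false
    });

    let response = reqwest::Client::new()
        .post("http://localhost:3000/v1/chat/completions")
        .json(&body)
        .send()
        .await?;

    // Non-streaming requests return a `ChatCompletion` body; with `"stream": true`
    // the route returns Server-Sent Events carrying `ChatCompletionChunk` payloads.
    println!("{}", response.text().await?);
    Ok(())
}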

Cargo.lock (generated): 39 changed lines

@@ -773,9 +773,9 @@ dependencies = [

 [[package]]
 name = "futures-channel"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
+checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
 dependencies = [
  "futures-core",
  "futures-sink",
@@ -783,9 +783,9 @@ dependencies = [

 [[package]]
 name = "futures-core"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
+checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"

 [[package]]
 name = "futures-executor"
@@ -800,15 +800,15 @@ dependencies = [

 [[package]]
 name = "futures-io"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
+checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"

 [[package]]
 name = "futures-macro"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
+checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -817,21 +817,21 @@ dependencies = [

 [[package]]
 name = "futures-sink"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"
+checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"

 [[package]]
 name = "futures-task"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
+checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"

 [[package]]
 name = "futures-util"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
+checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
 dependencies = [
  "futures-channel",
  "futures-core",
@@ -1373,6 +1373,15 @@ dependencies = [
  "unicase",
 ]

+[[package]]
+name = "minijinja"
+version = "1.0.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -2807,10 +2816,12 @@ dependencies = [
  "axum-tracing-opentelemetry",
  "clap",
  "futures",
+ "futures-util",
  "hf-hub",
  "init-tracing-opentelemetry",
  "metrics",
  "metrics-exporter-prometheus",
+ "minijinja",
  "ngrok",
  "nohash-hasher",
  "opentelemetry",


@@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
+minijinja = "1.0.10"
+futures-util = "0.3.30"

 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }


@@ -1,7 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
+use crate::HubTokenizerConfig;
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
-use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
 use std::sync::{
@@ -13,7 +14,7 @@ use text_generation_client::{
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -26,6 +27,8 @@ pub struct Infer {
     validation: Validation,
     /// Request queue
     queue: Queue,
+    /// Chat formatter
+    tokenizer_config: HubTokenizerConfig,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
@@ -52,6 +55,7 @@ impl Infer {
         window_size: Option<u32>,
         speculate: u32,
         generation_health: Arc<AtomicBool>,
+        tokenizer_config: HubTokenizerConfig,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size, speculate);
@@ -79,6 +83,7 @@ impl Infer {
             queue,
             shared,
             limit_concurrent_requests: semaphore,
+            tokenizer_config,
         }
     }
@@ -87,14 +92,7 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<
-        (
-            OwnedSemaphorePermit,
-            u32,
-            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-        ),
-        InferError,
-    > {
+    ) -> Result<GenerateStreamResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -119,7 +117,7 @@ impl Infer {
         // Append the request to the queue
         self.queue.append(Entry {
-            request: valid_request,
+            request: valid_request.clone(),
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -139,6 +137,14 @@
         ))
     }

+    /// Apply the chat template to the chat request
+    #[instrument(skip_all)]
+    pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
+        self.tokenizer_config
+            .apply_chat_template(chat)
+            .map_err(InferError::TemplateError)
+    }
+
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -201,6 +207,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prompt_token_count: valid_request.input_length,
                 prefill: result_prefill,
                 _input_length,
                 tokens: result_tokens,
@@ -550,9 +557,9 @@ fn send_responses(
     let mut iterator = tokens_
         .ids
         .into_iter()
-        .zip(tokens_.logprobs.into_iter())
-        .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
        .enumerate()
        .peekable();
     while let Some((i, (((id, logprob), text), special))) = iterator.next() {
@@ -665,6 +672,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
 }

 impl InferError {
@@ -674,6 +683,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::TemplateError(_) => "template_error",
         }
     }
 }


@@ -5,12 +5,22 @@ mod queue;
 pub mod server;
 mod validation;

-use infer::Infer;
+use crate::validation::ValidGenerateRequest;
+use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use tokio::sync::OwnedSemaphorePermit;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use utoipa::ToSchema;
 use validation::Validation;

+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    ValidGenerateRequest,
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -20,6 +30,28 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }

+#[derive(Clone, Deserialize)]
+pub struct HubTokenizerConfig {
+    #[serde(default)]
+    pub chat_template: Option<String>,
+}
+
+impl HubTokenizerConfig {
+    /// Apply the chat template to the chat request
+    pub(crate) fn apply_chat_template(
+        &self,
+        chat: ChatRequest,
+    ) -> Result<String, minijinja::Error> {
+        let mut env = minijinja::Environment::new();
+        let chat_template = self
+            .chat_template
+            .as_ref()
+            .ok_or(minijinja::ErrorKind::TemplateNotFound)?;
+        env.add_template("_", chat_template)?;
+        env.get_template("_")?.render(chat)
+    }
+}
+
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
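To make the template mechanics above concrete, here is a minimal, self-contained sketch of the same minijinja calls that apply_chat_template performs. The template string is invented for illustration; real models ship their own chat_template in tokenizer_config.json, and the router renders the whole deserialized ChatRequest as the template context.

use minijinja::{context, Environment};

fn main() -> Result<(), minijinja::Error> {
    // Illustrative template only; a model's real chat_template is usually much richer.
    let chat_template =
        "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}";

    let mut env = Environment::new();
    env.add_template("_", chat_template)?;

    // The router passes the deserialized ChatRequest; here an equivalent context is hand-built.
    let rendered = env.get_template("_")?.render(context! {
        messages => vec![
            context! { role => "system", content => "You are a helpful assistant." },
            context! { role => "user", content => "My name is David and I" },
        ],
    })?;

    println!("{rendered}");
    Ok(())
}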
@@ -152,7 +184,7 @@ fn default_parameters() -> GenerateParameters {
         top_k: None,
         top_p: None,
         typical_p: None,
-        do_sample: false,
+        do_sample: true,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
@@ -165,6 +197,190 @@
     }
 }

+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletion {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<ChatCompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionComplete {
+    pub index: u32,
+    pub message: Message,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: String,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+impl ChatCompletion {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        output: String,
+        created: u64,
+        details: Details,
+        return_logprobs: bool,
+    ) -> Self {
+        Self {
+            id: String::new(),
+            object: "text_completion".into(),
+            created,
+            model,
+            system_fingerprint,
+            choices: vec![ChatCompletionComplete {
+                index: 0,
+                message: Message {
+                    role: "assistant".into(),
+                    content: output,
+                },
+                logprobs: return_logprobs
+                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                finish_reason: details.finish_reason.to_string(),
+            }],
+            usage: Usage {
+                prompt_tokens: details.prompt_token_count,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prompt_token_count + details.generated_tokens,
+            },
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<ChatCompletionChoice>,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionChoice {
+    pub index: u32,
+    pub delta: ChatCompletionDelta,
+    pub logprobs: Option<f32>,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionDelta {
+    pub role: String,
+    pub content: String,
+}
+
+impl ChatCompletionChunk {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        delta: String,
+        created: u64,
+        index: u32,
+        logprobs: Option<f32>,
+        finish_reason: Option<String>,
+    ) -> Self {
+        Self {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created,
+            model,
+            system_fingerprint,
+            choices: vec![ChatCompletionChoice {
+                index,
+                delta: ChatCompletionDelta {
+                    role: "assistant".to_string(),
+                    content: delta,
+                },
+                logprobs,
+                finish_reason,
+            }],
+        }
+    }
+}
+
+fn default_request_messages() -> Vec<Message> {
+    vec![Message {
+        role: "system".to_string(),
+        content: "My name is David and I".to_string(),
+    }]
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct ChatRequest {
+    /// UNUSED
+    #[schema(example = "bigscience/blomm-560m")]
+    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: String, /* NOTE: UNUSED */
+
+    /// A list of messages comprising the conversation so far.
+    #[serde(default = "default_request_messages")]
+    pub messages: Vec<Message>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    pub frequency_penalty: Option<f32>,
+
+    /// UNUSED
+    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+    /// result in a ban or exclusive selection of the relevant token.
+    #[serde(default)]
+    pub logit_bias: Option<Vec<f32>>,
+
+    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
+    /// output token returned in the content of message.
+    #[serde(default)]
+    pub logprobs: Option<bool>,
+
+    /// UNUSED
+    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
+    /// an associated log probability. logprobs must be set to true if this parameter is used.
+    #[serde(default)]
+    pub top_logprobs: Option<u32>,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    pub max_tokens: Option<u32>,
+
+    /// UNUSED
+    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
+    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
+    #[serde(default)]
+    pub n: Option<u32>,
+
+    /// UNUSED
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
+    /// increasing the model's likelihood to talk about new topics
+    #[serde(default)]
+    pub presence_penalty: Option<f32>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct Message {
+    #[schema(example = "system")]
+    pub role: String,
+    #[schema(example = "My name is David and I")]
+    pub content: String,
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateRequest {
     #[schema(example = "My name is Olivier and I")]
@@ -227,6 +443,16 @@ pub(crate) enum FinishReason {
     StopSequence,
 }

+impl std::fmt::Display for FinishReason {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FinishReason::Length => write!(f, "length"),
+            FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
+            FinishReason::StopSequence => write!(f, "stop_sequence"),
+        }
+    }
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct BestOfSequence {
     #[schema(example = "test")]
@@ -257,6 +483,8 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[schema(example = 1)]
+    pub prompt_token_count: u32,
 }

 #[derive(Serialize, ToSchema)]
@@ -279,6 +507,7 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
+    pub index: u32,
     pub token: Token,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Token>,


@@ -1,22 +1,19 @@
+/// Text Generation Inference webserver entrypoint
 use axum::http::HeaderValue;
 use clap::Parser;
-use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
-use hf_hub::{Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
 use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
-/// Text Generation Inference webserver entrypoint
-use std::fs::File;
-use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
+use std::time::Duration;
 use text_generation_client::{ClientError, ShardedClient};
-use text_generation_router::{server, HubModelInfo};
+use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
 use thiserror::Error;
-use tokenizers::Tokenizer;
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -72,8 +69,7 @@ struct Args {
     ngrok_edge: Option<String>,
 }

-#[tokio::main]
-async fn main() -> Result<(), RouterError> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -102,9 +98,6 @@
         ngrok_edge,
     } = args;

-    // Launch Tokio runtime
-    init_logging(otlp_endpoint, json_output);
-
     // Validate args
     if max_input_length >= max_total_tokens {
         return Err(RouterError::ArgumentValidation(
@@ -148,158 +141,161 @@
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
     let local_model = local_path.exists() && local_path.is_dir();

-    let (tokenizer, model_info) = if local_model {
-        // Get Model info
-        let model_info = HubModelInfo {
-            model_id: tokenizer_name.clone(),
-            sha: None,
-            pipeline_tag: None,
-        };
+    let tokenizer = if local_model {
         // Load local tokenizer
-        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
-
-        (tokenizer, model_info)
+        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
     } else {
-        let mut builder = ApiBuilder::new()
-            .with_progress(false)
-            .with_token(authorization_token);
-
-        if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() {
-            builder = builder.with_cache_dir(cache_dir.into());
-        }
-
-        if revision.is_none() {
-            tracing::warn!("`--revision` is not set");
-            tracing::warn!("We strongly advise to set it to a known supported commit.");
-        }
-
-        let api = builder.build().unwrap();
-        let api_repo = api.repo(Repo::with_revision(
-            tokenizer_name.clone(),
-            RepoType::Model,
-            revision.clone().unwrap_or("main".to_string()),
-        ));
-
-        // Get Model info
-        let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-            HubModelInfo {
-                model_id: tokenizer_name.to_string(),
-                sha: None,
-                pipeline_tag: None,
-            }
-        });
-
-        let tokenizer = match api_repo.get("tokenizer.json").await {
-            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
-            Err(_) => get_base_tokenizer(&api, &api_repo).await,
-        };
-
-        (tokenizer, model_info)
+        // Download and instantiate tokenizer
+        // We need to download it outside of the Tokio runtime
+        let params = FromPretrainedParameters {
+            revision: revision.clone().unwrap_or("main".to_string()),
+            auth_token: authorization_token.clone(),
+            ..Default::default()
+        };
+        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
     };

-    if tokenizer.is_none() {
-        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
-        tracing::warn!("Rust input length validation and truncation is disabled");
-    }
-
-    // if pipeline-tag == text-generation we default to return_full_text = true
-    let compat_return_full_text = match &model_info.pipeline_tag {
-        None => {
-            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
-            false
-        }
-        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
-    };
-
-    // Instantiate sharded client from the master unix socket
-    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
-        .await
-        .map_err(RouterError::Connection)?;
-    // Clear the cache; useful if the webserver rebooted
-    sharded_client
-        .clear_cache(None)
-        .await
-        .map_err(RouterError::Cache)?;
-    // Get info from the shard
-    let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
-
-    // Warmup model
-    tracing::info!("Warming up model");
-    let max_supported_batch_total_tokens = match sharded_client
-        .warmup(
-            max_input_length as u32,
-            max_batch_prefill_tokens,
-            max_total_tokens as u32,
-        )
-        .await
-        .map_err(RouterError::Warmup)?
-    {
-        // Older models do not support automatic max-batch-total-tokens
-        None => {
-            let max_batch_total_tokens = max_batch_total_tokens
-                .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
-            tracing::warn!("Model does not support automatic max batch total tokens");
-            max_batch_total_tokens
-        }
-        // Flash attention models return their max supported total tokens
-        Some(max_supported_batch_total_tokens) => {
-            // Warn if user added his own max-batch-total-tokens as we will ignore it
-            if max_batch_total_tokens.is_some() {
-                tracing::warn!(
-                    "`--max-batch-total-tokens` is deprecated for Flash \
+    // Launch Tokio runtime
+    tokio::runtime::Builder::new_multi_thread()
+        .enable_all()
+        .build()?
+        .block_on(async {
+            init_logging(otlp_endpoint, json_output);
+
+            if tokenizer.is_none() {
+                tracing::warn!(
+                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
+                );
+                tracing::warn!("Rust input length validation and truncation is disabled");
+            }
+
+            // Get Model info
+            let model_info = match local_model {
+                true => HubModelInfo {
+                    model_id: tokenizer_name.clone(),
+                    sha: None,
+                    pipeline_tag: None,
+                },
+                false => get_model_info(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
+                    .await
+                    .unwrap_or_else(|| {
+                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                        HubModelInfo {
+                            model_id: tokenizer_name.to_string(),
+                            sha: None,
+                            pipeline_tag: None,
+                        }
+                    }),
+            };
+
+            let tokenizer_config: HubTokenizerConfig = match local_model {
+                true => HubTokenizerConfig{
+                    chat_template: None,
+                },
+                false => get_tokenizer_config(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
+                    .await.unwrap_or_else(|| {
+                        tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
+                        HubTokenizerConfig{
+                            chat_template: None,
+                        }
+                    }),
+            };
+
+            // if pipeline-tag == text-generation we default to return_full_text = true
+            let compat_return_full_text = match &model_info.pipeline_tag {
+                None => {
+                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
+                    false
+                }
+                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
+            };
+
+            // Instantiate sharded client from the master unix socket
+            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+                .await
+                .map_err(RouterError::Connection)?;
+            // Clear the cache; useful if the webserver rebooted
+            sharded_client
+                .clear_cache(None)
+                .await
+                .map_err(RouterError::Cache)?;
+            // Get info from the shard
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
+
+            // Warmup model
+            tracing::info!("Warming up model");
+            let max_supported_batch_total_tokens = match sharded_client
+                .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
+                .await
+                .map_err(RouterError::Warmup)?
+            {
+                // Older models do not support automatic max-batch-total-tokens
+                None => {
+                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
+                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
+                    );
+                    tracing::warn!("Model does not support automatic max batch total tokens");
+                    max_batch_total_tokens
+                }
+                // Flash attention models return their max supported total tokens
+                Some(max_supported_batch_total_tokens) => {
+                    // Warn if user added his own max-batch-total-tokens as we will ignore it
+                    if max_batch_total_tokens.is_some() {
+                        tracing::warn!(
+                            "`--max-batch-total-tokens` is deprecated for Flash \
                             Attention models."
                         );
                         tracing::warn!(
                             "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                         );
                     }
                     if max_total_tokens as u32 > max_supported_batch_total_tokens {
                         return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
                     }

                     max_supported_batch_total_tokens
                 }
             };
             tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
             tracing::info!("Connected");

             let addr = match hostname.parse() {
                 Ok(ip) => SocketAddr::new(ip, port),
                 Err(_) => {
                     tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
                     SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
                 }
             };

             // Run server
             server::run(
                 model_info,
                 shard_info,
                 compat_return_full_text,
                 max_concurrent_requests,
                 max_best_of,
                 max_stop_sequences,
                 max_top_n_tokens,
                 max_input_length,
                 max_total_tokens,
                 waiting_served_ratio,
                 max_batch_prefill_tokens,
                 max_supported_batch_total_tokens,
                 max_waiting_tokens,
                 sharded_client,
                 tokenizer,
                 validation_workers,
                 addr,
                 cors_allow_origin,
                 ngrok,
                 ngrok_authtoken,
                 ngrok_edge,
+                tokenizer_config,
             )
             .await?;
             Ok(())
+        })
 }

 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
@@ -358,8 +354,30 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }

 /// get model info from the Huggingface Hub
-pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
-    let response = api.info_request().send().await.ok()?;
+pub async fn get_model_info(
+    model_id: &str,
+    revision: Option<&str>,
+    token: Option<&str>,
+) -> Option<HubModelInfo> {
+    let revision = match revision {
+        None => {
+            tracing::warn!("`--revision` is not set");
+            tracing::warn!("We strongly advise to set it to a known supported commit.");
+            "main".to_string()
+        }
+        Some(revision) => revision.to_string(),
+    };
+
+    let client = reqwest::Client::new();
+    // Poor man's urlencode
+    let revision = revision.replace('/', "%2F");
+    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
+    let mut builder = client.get(url).timeout(Duration::from_secs(5));
+    if let Some(token) = token {
+        builder = builder.bearer_auth(token);
+    }
+
+    let response = builder.send().await.ok()?;

     if response.status().is_success() {
         let hub_model_info: HubModelInfo =
@@ -376,26 +394,36 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
     }
 }

-/// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
-    let config_filename = api_repo.get("config.json").await.ok()?;
-
-    // Open the file in read-only mode with buffer.
-    let file = File::open(config_filename).ok()?;
-    let reader = BufReader::new(file);
-
-    // Read the JSON contents of the file as an instance of `User`.
-    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
-
-    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
-        let api_base_repo = api.repo(Repo::with_revision(
-            base_model_id.to_string(),
-            RepoType::Model,
-            "main".to_string(),
-        ));
-
-        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
-        Tokenizer::from_file(tokenizer_filename).ok()
+/// get tokenizer_config from the Huggingface Hub
+pub async fn get_tokenizer_config(
+    model_id: &str,
+    revision: Option<&str>,
+    token: Option<&str>,
+) -> Option<HubTokenizerConfig> {
+    let revision = match revision {
+        None => {
+            tracing::warn!("`--revision` is not set");
+            tracing::warn!("We strongly advise to set it to a known supported commit.");
+            "main".to_string()
+        }
+        Some(revision) => revision.to_string(),
+    };
+
+    let client = reqwest::Client::new();
+    // Poor man's urlencode
+    let revision = revision.replace('/', "%2F");
+    let url = format!(
+        "https://huggingface.co/{}/raw/{}/tokenizer_config.json",
+        model_id, revision
+    );
+    let mut builder = client.get(url).timeout(Duration::from_secs(5));
+    if let Some(token) = token {
+        builder = builder.bearer_auth(token);
+    }
+
+    let response = builder.send().await.ok()?;
+
+    if response.status().is_success() {
+        let text = response.text().await.ok()?;
+        let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?;
+        Some(hub_tokenizer_config)
     } else {
         None
     }


@@ -2,10 +2,11 @@
 use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
+use crate::HubTokenizerConfig;
 use crate::{
-    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
-    StreamDetails, StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -207,6 +208,7 @@
                     seed: response.generated_text.seed,
                     best_of_sequences,
                     top_tokens: response.top_tokens,
+                    prompt_token_count: response.prompt_token_count,
                 })
             }
             false => None,
@@ -343,6 +345,21 @@
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
+    let on_message_callback = |stream_token: StreamResponse| {
+        let event = Event::default();
+        event.json_data(stream_token).unwrap()
+    };
+    let (headers, response_stream) =
+        generate_stream_internal(infer, Json(req), on_message_callback).await;
+    let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+    (headers, sse)
+}
+
+async fn generate_stream_internal(
+    infer: Infer,
+    Json(req): Json<GenerateRequest>,
+    on_message_callback: impl Fn(StreamResponse) -> Event,
+) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
@@ -387,6 +404,7 @@
             Ok((_permit, _input_length, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
+                    index += 1;
                     match response {
                         Ok(response) => {
                             match response {
@@ -401,13 +419,14 @@
                                 // StreamResponse
                                 let stream_token = StreamResponse {
+                                    index,
                                     token,
                                     top_tokens,
                                     generated_text: None,
                                     details: None,
                                 };
-
-                                yield Ok(Event::default().json_data(stream_token).unwrap())
+                                let event = on_message_callback(stream_token);
+                                yield Ok(event);
                             }
                             // Yield event for last token and compute timings
                             InferStreamResponse::End {
@@ -463,13 +482,16 @@
                                 tracing::info!(parent: &span, "Success");

                                 let stream_token = StreamResponse {
+                                    index,
                                     token,
                                     top_tokens,
                                     generated_text: Some(output_text),
                                     details
                                 };
-
-                                yield Ok(Event::default().json_data(stream_token).unwrap());
+
+                                let event = on_message_callback(stream_token);
+                                yield Ok(event);
                                 break;
                             }
                         }
@@ -500,7 +522,153 @@
         }
     };

-    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
+    (headers, stream)
+}
+
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/chat/completions",
+    request_body = ChatRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = GenerateResponse),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+        example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+        example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+        example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+        example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        // parameters = ? req.parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn chat_completions(
+    Extension(infer): Extension<Infer>,
+    Extension(info): Extension<Info>,
+    Json(req): Json<ChatRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    let stream = req.stream;
+    let max_new_tokens = req.max_tokens.or(Some(100));
+    let repetition_penalty = req
+        .frequency_penalty
+        // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
+        .map(|x| x + 2.0);
+    let logprobs = req.logprobs.unwrap_or(false);
+
+    // apply chat template to flatten the request into a single input
+    let inputs = match infer.apply_chat_template(req) {
+        Ok(inputs) => inputs,
+        Err(err) => {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            return Err((
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                    error_type: err.error_type().to_string(),
+                }),
+            ));
+        }
+    };
+
+    // build the request passing some parameters
+    let generate_request = GenerateRequest {
+        inputs: inputs.to_string(),
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature: None,
+            repetition_penalty,
+            top_k: None,
+            top_p: None,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens,
+            return_full_text: None,
+            stop: Vec::new(),
+            truncate: None,
+            watermark: false,
+            details: true,
+            decoder_input_details: false,
+            seed: None,
+            top_n_tokens: None,
+        },
+    };
+
+    // static values that will be returned in all cases
+    let model_id = info.model_id.clone();
+    let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+
+    // switch on stream
+    if stream {
+        // pass this callback to the stream generation and build the required event structure
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(ChatCompletionChunk::new(
+                    model_id.clone(),
+                    system_fingerprint.clone(),
+                    stream_token.token.text,
+                    current_time,
+                    stream_token.index,
+                    logprobs.then_some(stream_token.token.logprob),
+                    stream_token.details.map(|d| d.finish_reason.to_string()),
+                ))
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) =
+            generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) =
+            generate(Extension(infer), Json(generate_request)).await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        // build the complete response object with the full text
+        let response = ChatCompletion::new(
+            generation.generated_text,
+            model_id,
+            system_fingerprint,
+            current_time,
+            generation.details.unwrap(),
+            logprobs,
+        );
+
+        // wrap generation inside a Vec to match api-inference
+        Ok((headers, Json(response)).into_response())
+    }
 }

 /// Prometheus metrics scrape endpoint
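One detail of the handler above that is easy to miss is the parameter rescaling: OpenAI's frequency_penalty is defined on (-2.0, 2.0) while the router maps it onto repetition_penalty in (0.0, 4.0) by shifting it by 2.0. A tiny standalone sketch of that mapping (the helper name here is invented for illustration):

// Mirrors the `.map(|x| x + 2.0)` rescaling in `chat_completions` above.
fn frequency_to_repetition_penalty(frequency_penalty: Option<f32>) -> Option<f32> {
    frequency_penalty.map(|x| x + 2.0)
}

fn main() {
    assert_eq!(frequency_to_repetition_penalty(Some(0.5)), Some(2.5));
    assert_eq!(frequency_to_repetition_penalty(Some(-1.0)), Some(1.0));
    // `None` is passed through, so the router falls back to its default repetition penalty.
    assert_eq!(frequency_to_repetition_penalty(None), None);
}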
@@ -538,6 +706,7 @@ pub async fn run(
     ngrok: bool,
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
+    tokenizer_config: HubTokenizerConfig,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -604,6 +773,7 @@ pub async fn run(
         shard_info.window_size,
         shard_info.speculate,
         generation_health,
+        tokenizer_config,
     );

     // Duration buckets
@@ -693,6 +863,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
+        .route("/v1/chat/completions", post(chat_completions))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
@@ -822,6 +993,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
         };
         (


@@ -376,7 +376,7 @@ type TokenizerRequest = (
     Span,
 );

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,