mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: supports openai chat completions API
This commit is contained in:
parent
630800eed3
commit
3ae9cd655d
39
Cargo.lock
generated
39
Cargo.lock
generated
@ -773,9 +773,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.29"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
|
||||
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
@ -783,9 +783,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.29"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
|
||||
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
@ -800,15 +800,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.29"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
|
||||
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.29"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
|
||||
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -817,21 +817,21 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.29"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"
|
||||
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.29"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
|
||||
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.29"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
|
||||
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@ -1369,6 +1369,15 @@ dependencies = [
|
||||
"unicase",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minijinja"
|
||||
version = "1.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minimal-lexical"
|
||||
version = "0.2.1"
|
||||
@ -2803,10 +2812,12 @@ dependencies = [
|
||||
"axum-tracing-opentelemetry",
|
||||
"clap",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"hf-hub",
|
||||
"init-tracing-opentelemetry",
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"minijinja",
|
||||
"ngrok",
|
||||
"nohash-hasher",
|
||||
"opentelemetry",
|
||||
|
@ -43,6 +43,8 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
hf-hub = "0.3.1"
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = "1.0.10"
|
||||
futures-util = "0.3.30"
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||
|
@ -1,7 +1,8 @@
|
||||
/// Batching and inference logic
|
||||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::HubTokenizerConfig;
|
||||
use crate::{ChatRequest, GenerateRequest, PrefillToken};
|
||||
use crate::{Entry, Queue, Token};
|
||||
use crate::{GenerateRequest, PrefillToken};
|
||||
use futures::future::try_join_all;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
@ -26,6 +27,8 @@ pub struct Infer {
|
||||
validation: Validation,
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Chat formatter
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
/// Shared state
|
||||
shared: Arc<Shared>,
|
||||
/// Inference limit
|
||||
@ -52,6 +55,7 @@ impl Infer {
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||
@ -79,6 +83,7 @@ impl Infer {
|
||||
queue,
|
||||
shared,
|
||||
limit_concurrent_requests: semaphore,
|
||||
tokenizer_config,
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,6 +138,28 @@ impl Infer {
|
||||
Ok((permit, UnboundedReceiverStream::new(response_rx)))
|
||||
}
|
||||
|
||||
/// Apply the chat template to the chat request
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn apply_chat_template(
|
||||
&self,
|
||||
chat: ChatRequest,
|
||||
) -> Result<String, ChatTemplateError> {
|
||||
let mut env = minijinja::Environment::new();
|
||||
let chat_template = self
|
||||
.tokenizer_config
|
||||
.chat_template
|
||||
.as_ref()
|
||||
.ok_or(ChatTemplateError::TemplateNotFound)?;
|
||||
env.add_template("_", chat_template)
|
||||
.map_err(|e| ChatTemplateError::TemplateError(e))?;
|
||||
let jinja_tmpl = env
|
||||
.get_template("_")
|
||||
.map_err(|e| ChatTemplateError::TemplateError(e))?;
|
||||
jinja_tmpl
|
||||
.render(chat)
|
||||
.map_err(|e| ChatTemplateError::TemplateError(e))
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a InferResponse
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn generate(
|
||||
@ -666,3 +693,20 @@ impl InferError {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ChatTemplateError {
|
||||
#[error("Template error: {0}")]
|
||||
TemplateError(#[from] minijinja::Error),
|
||||
#[error("Template not found")]
|
||||
TemplateNotFound,
|
||||
}
|
||||
|
||||
impl ChatTemplateError {
|
||||
pub(crate) fn error_type(&self) -> &str {
|
||||
match self {
|
||||
ChatTemplateError::TemplateError(_) => "template_error",
|
||||
ChatTemplateError::TemplateNotFound => "template_not_found",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,6 +20,12 @@ pub struct HubModelInfo {
|
||||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct HubTokenizerConfig {
|
||||
#[serde(default)]
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
@ -165,6 +171,182 @@ fn default_parameters() -> GenerateParameters {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub(crate) struct ChatCompletion {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub system_fingerprint: String,
|
||||
pub choices: Vec<ChatCompletionComplete>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub(crate) struct ChatCompletionComplete {
|
||||
pub index: u32,
|
||||
pub message: Message,
|
||||
pub logprobs: Option<Vec<f32>>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub(crate) struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
impl ChatCompletion {
|
||||
pub(crate) fn new(
|
||||
ouput: String,
|
||||
created: u64,
|
||||
details: Details,
|
||||
prompt_character_count: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created,
|
||||
model: "".to_string(),
|
||||
system_fingerprint: "".to_string(),
|
||||
choices: vec![ChatCompletionComplete {
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: ouput,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: prompt_character_count,
|
||||
completion_tokens: details.generated_tokens,
|
||||
total_tokens: prompt_character_count + details.generated_tokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub(crate) struct ChatCompletionChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub system_fingerprint: String,
|
||||
pub choices: Vec<ChatCompletionChoice>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub(crate) struct ChatCompletionChoice {
|
||||
pub index: u32,
|
||||
pub delta: ChatCompletionDelta,
|
||||
pub logprobs: Option<Vec<f32>>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct ChatCompletionDelta {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatCompletionChunk {
|
||||
pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
|
||||
Self {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created,
|
||||
model: "".to_string(),
|
||||
system_fingerprint: "".to_string(),
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index,
|
||||
delta: ChatCompletionDelta {
|
||||
role: "assistant".to_string(),
|
||||
content: delta,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_request_messages() -> Vec<Message> {
|
||||
vec![Message {
|
||||
role: "system".to_string(),
|
||||
content: "My name is David and I".to_string(),
|
||||
}]
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
pub(crate) struct ChatRequest {
|
||||
/// UNUSED
|
||||
#[schema(example = "bigscience/blomm-560m")]
|
||||
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||
pub model: String, /* NOTE: UNUSED */
|
||||
|
||||
/// A list of messages comprising the conversation so far.
|
||||
#[serde(default = "default_request_messages")]
|
||||
pub messages: Vec<Message>,
|
||||
|
||||
/// UNUSED
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
|
||||
/// decreasing the model's likelihood to repeat the same line verbatim.
|
||||
#[serde(default)]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// UNUSED
|
||||
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
|
||||
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
|
||||
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
|
||||
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
|
||||
/// result in a ban or exclusive selection of the relevant token.
|
||||
#[serde(default)]
|
||||
pub logit_bias: Option<Vec<f32>>,
|
||||
|
||||
/// UNUSED
|
||||
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
|
||||
/// output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview
|
||||
/// model.
|
||||
#[serde(default)]
|
||||
pub logprobs: Option<u32>,
|
||||
|
||||
/// UNUSED
|
||||
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
|
||||
/// an associated log probability. logprobs must be set to true if this parameter is used.
|
||||
#[serde(default)]
|
||||
pub top_logprobs: Option<u32>,
|
||||
|
||||
/// The maximum number of tokens that can be generated in the chat completion.
|
||||
#[serde(default)]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// UNUSED
|
||||
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
|
||||
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
|
||||
#[serde(default)]
|
||||
pub n: Option<u32>,
|
||||
|
||||
/// UNUSED
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
|
||||
/// increasing the model's likelihood to talk about new topics
|
||||
#[serde(default)]
|
||||
pub presence_penalty: Option<f32>,
|
||||
|
||||
#[serde(default = "bool::default")]
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
pub(crate) struct Message {
|
||||
#[schema(example = "system")]
|
||||
pub role: String,
|
||||
#[schema(example = "My name is David and I")]
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
pub(crate) struct GenerateRequest {
|
||||
#[schema(example = "My name is Olivier and I")]
|
||||
|
@ -11,7 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::{server, HubModelInfo};
|
||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||
use thiserror::Error;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
use tower_http::cors::AllowOrigin;
|
||||
@ -176,7 +176,7 @@ fn main() -> Result<(), RouterError> {
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
},
|
||||
false => get_model_info(&tokenizer_name, revision, authorization_token)
|
||||
false => get_model_info(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
@ -188,6 +188,20 @@ fn main() -> Result<(), RouterError> {
|
||||
}),
|
||||
};
|
||||
|
||||
let tokenizer_config: HubTokenizerConfig = match local_model {
|
||||
true => HubTokenizerConfig{
|
||||
chat_template: None,
|
||||
},
|
||||
false => get_tokenizer_config(&tokenizer_name, revision.as_deref(), authorization_token.as_deref())
|
||||
.await.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
|
||||
HubTokenizerConfig{
|
||||
chat_template: None,
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
|
||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||
None => {
|
||||
@ -277,6 +291,7 @@ fn main() -> Result<(), RouterError> {
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
tokenizer_config,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
@ -341,8 +356,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||
/// get model info from the Huggingface Hub
|
||||
pub async fn get_model_info(
|
||||
model_id: &str,
|
||||
revision: Option<String>,
|
||||
token: Option<String>,
|
||||
revision: Option<&str>,
|
||||
token: Option<&str>,
|
||||
) -> Option<HubModelInfo> {
|
||||
let revision = match revision {
|
||||
None => {
|
||||
@ -350,7 +365,7 @@ pub async fn get_model_info(
|
||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||
"main".to_string()
|
||||
}
|
||||
Some(revision) => revision,
|
||||
Some(revision) => revision.to_string(),
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
@ -379,6 +394,41 @@ pub async fn get_model_info(
|
||||
}
|
||||
}
|
||||
|
||||
/// get tokenizer_config from the Huggingface Hub
|
||||
pub async fn get_tokenizer_config(
|
||||
model_id: &str,
|
||||
revision: Option<&str>,
|
||||
token: Option<&str>,
|
||||
) -> Option<HubTokenizerConfig> {
|
||||
let revision = match revision {
|
||||
None => {
|
||||
tracing::warn!("`--revision` is not set");
|
||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||
"main".to_string()
|
||||
}
|
||||
Some(revision) => revision.to_string(),
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
// Poor man's urlencode
|
||||
let revision = revision.replace('/', "%2F");
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/raw/{}/tokenizer_config.json",
|
||||
model_id, revision
|
||||
);
|
||||
let mut builder = client.get(url).timeout(Duration::from_secs(5));
|
||||
if let Some(token) = token {
|
||||
builder = builder.bearer_auth(token);
|
||||
}
|
||||
let response = builder.send().await.ok()?;
|
||||
if response.status().is_success() {
|
||||
let text = response.text().await.ok()?;
|
||||
let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?;
|
||||
Some(hub_tokenizer_config)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
|
@ -2,10 +2,11 @@
|
||||
use crate::health::Health;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::HubTokenizerConfig;
|
||||
use crate::{
|
||||
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
|
||||
StreamDetails, StreamResponse, Token, Validation,
|
||||
BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
|
||||
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
||||
HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
@ -337,6 +338,22 @@ async fn generate_stream(
|
||||
HeaderMap,
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
) {
|
||||
let on_message_callback = |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
event.json_data(stream_token).unwrap()
|
||||
};
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, Json(req.into()), on_message_callback).await;
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
(headers, sse)
|
||||
}
|
||||
|
||||
///
|
||||
async fn generate_stream_internal(
|
||||
infer: Infer,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||
let span = tracing::Span::current();
|
||||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
@ -401,7 +418,8 @@ async fn generate_stream(
|
||||
details: None,
|
||||
};
|
||||
|
||||
yield Ok(Event::default().json_data(stream_token).unwrap())
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
}
|
||||
// Yield event for last token and compute timings
|
||||
InferStreamResponse::End {
|
||||
@ -463,7 +481,9 @@ async fn generate_stream(
|
||||
details
|
||||
};
|
||||
|
||||
yield Ok(Event::default().json_data(stream_token).unwrap());
|
||||
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -494,7 +514,141 @@ async fn generate_stream(
|
||||
}
|
||||
};
|
||||
|
||||
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
|
||||
(headers, stream)
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v1/chat/completions",
|
||||
request_body = ChatRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
// parameters = ? req.parameters,
|
||||
total_time,
|
||||
validation_time,
|
||||
queue_time,
|
||||
inference_time,
|
||||
time_per_token,
|
||||
seed,
|
||||
)
|
||||
)]
|
||||
async fn chat(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
// extract the values we need for the chat request
|
||||
let stream = req.stream;
|
||||
let max_new_tokens = match req.max_tokens {
|
||||
Some(max_new_tokens) => Some(max_new_tokens),
|
||||
None => Some(100)
|
||||
};
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(req) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// poor man's token count (assumes that each character is a token)
|
||||
let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: None,
|
||||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: false,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: false,
|
||||
seed: None,
|
||||
top_n_tokens: None,
|
||||
},
|
||||
};
|
||||
|
||||
// switch on stream
|
||||
if stream {
|
||||
// pass this callback to the stream generation and build the required event structure
|
||||
let on_message_callback = |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
event
|
||||
.json_data(ChatCompletionChunk::new(
|
||||
stream_token.token.text,
|
||||
current_time,
|
||||
0,
|
||||
))
|
||||
.unwrap_or_else(|_| {
|
||||
println!("Failed to serialize ChatCompletionChunk");
|
||||
Event::default()
|
||||
})
|
||||
};
|
||||
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
let (headers, Json(generation)) =
|
||||
generate(Extension(infer), Json(generate_request)).await?;
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
// build the complete response object with the full text
|
||||
let response = ChatCompletion::new(
|
||||
generation.generated_text,
|
||||
current_time,
|
||||
generation.details.unwrap(),
|
||||
prompt_character_count,
|
||||
);
|
||||
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
Ok((headers, Json(response)).into_response())
|
||||
}
|
||||
}
|
||||
|
||||
/// Prometheus metrics scrape endpoint
|
||||
@ -532,6 +686,7 @@ pub async fn run(
|
||||
ngrok: bool,
|
||||
ngrok_authtoken: Option<String>,
|
||||
ngrok_edge: Option<String>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
) -> Result<(), axum::BoxError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -598,6 +753,7 @@ pub async fn run(
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
generation_health,
|
||||
tokenizer_config,
|
||||
);
|
||||
|
||||
// Duration buckets
|
||||
@ -687,6 +843,7 @@ pub async fn run(
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/v1/chat/completions", post(chat))
|
||||
// AWS Sagemaker route
|
||||
.route("/invocations", post(compat_generate))
|
||||
// Base Health route
|
||||
|
Loading…
Reference in New Issue
Block a user