Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: adds index, model id, system fingerprint and updates do_sample param
This commit is contained in:
parent: ddf7412a6b
commit: f82ff3f64a
@@ -158,7 +158,7 @@ fn default_parameters() -> GenerateParameters {
         top_k: None,
         top_p: None,
         typical_p: None,
-        do_sample: false,
+        do_sample: true,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
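A hedged illustration of the default flip above (default_parameters() is the serde default shown in the hunk; calling it directly like this is only for the sketch):

    // Sketch only: requests that omit `parameters` now sample by default.
    let params = default_parameters();
    assert!(params.do_sample);       // was false before this commit
    assert!(params.top_p.is_none()); // the other sampling knobs remain unset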
@@ -253,21 +253,29 @@ pub(crate) struct ChatCompletionDelta {
 }

 impl ChatCompletionChunk {
-    pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        delta: String,
+        created: u64,
+        index: u32,
+        logprobs: Option<Vec<f32>>,
+        finish_reason: Option<String>,
+    ) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
             created,
-            model: "".to_string(),
-            system_fingerprint: "".to_string(),
+            model,
+            system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index,
                 delta: ChatCompletionDelta {
                     role: "assistant".to_string(),
                     content: delta,
                 },
-                logprobs: None,
-                finish_reason: None,
+                logprobs,
+                finish_reason,
             }],
         }
     }
 }
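For reference, a hedged sketch of how the widened constructor would be called; the argument values are illustrative only and not taken from the commit:

    // Sketch only: illustrative call to the new ChatCompletionChunk::new signature.
    let chunk = ChatCompletionChunk::new(
        "my-model".to_string(),     // model
        "1.4.0-native".to_string(), // system_fingerprint
        "Hello".to_string(),        // delta (token text)
        1_700_000_000,              // created (unix seconds)
        0,                          // index of this chunk within the stream
        None,                       // logprobs
        None,                       // finish_reason (Some(...) on the final chunk)
    );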
@@ -21,7 +21,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
-use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -339,7 +338,7 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
-    let on_message_callback = |stream_token: StreamResponse| {
+    let on_message_callback = |_: u32, stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
@@ -353,7 +352,7 @@ async fn generate_stream(
 async fn generate_stream_internal(
     infer: Infer,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(StreamResponse) -> Event,
+    on_message_callback: impl Fn(u32, StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
@@ -397,8 +396,10 @@ async fn generate_stream_internal(
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
             Ok((_permit, mut response_stream)) => {
+                let mut index = 0;
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
+                    index += 1;
                     match response {
                         Ok(response) => {
                             match response {
@@ -418,8 +419,7 @@ async fn generate_stream_internal(
                                     generated_text: None,
                                     details: None,
                                 };
-
-                                let event = on_message_callback(stream_token);
+                                let event = on_message_callback(index, stream_token);
                                 yield Ok(event);
                             }
                             // Yield event for last token and compute timings
@@ -483,7 +483,7 @@ async fn generate_stream_internal(
                                 };


-                                let event = on_message_callback(stream_token);
+                                let event = on_message_callback(index, stream_token);
                                 yield Ok(event);
                                 break;
                             }
@@ -550,6 +550,7 @@ async fn generate_stream_internal(
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
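The handler can only read model_id, version and docker_label if the router's Info struct is available as request state. A hedged sketch of that wiring in axum; the route path and layering shown here are assumptions, not copied from the commit:

    // Sketch only: how the Info extension is assumed to reach chat_completions.
    use axum::{routing::post, Extension, Router};

    let app = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .layer(Extension(infer)) // shared inference handle
        .layer(Extension(info)); // router/model metadata read by the handler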
@@ -605,9 +606,14 @@ async fn chat_completions(

     // switch on stream
     if stream {
-        let stream_count = AtomicU32::new(0);
+        let model_id = info.model_id.clone();
+        let system_fingerprint = format!(
+            "{}-{}",
+            info.version,
+            info.docker_label.unwrap_or("native")
+        );
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = move |stream_token: StreamResponse| {
+        let on_message_callback = move |index: u32, stream_token: StreamResponse| {
             let event = Event::default();

             let current_time = std::time::SystemTime::now()
@@ -615,15 +621,15 @@ async fn chat_completions(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();

-            // increment the stream count
-            stream_count.fetch_add(1, Ordering::SeqCst);
-            let current_stream_count = stream_count.load(Ordering::SeqCst);
-
             event
                 .json_data(ChatCompletionChunk::new(
+                    model_id.clone(),
+                    system_fingerprint.clone(),
                     stream_token.token.text,
                     current_time,
-                    current_stream_count,
+                    index,
+                    None,
+                    None,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
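Taken together, the streaming path of chat_completions now numbers chunks with the index handed in by generate_stream_internal instead of keeping its own AtomicU32 counter, and stamps each chunk with the model id and a version-derived fingerprint. A condensed, hedged sketch of the resulting callback (simplified from the hunks above; headers setup and error handling are elided or assumed):

    // Sketch only: condensed view of the new streaming callback in chat_completions.
    let model_id = info.model_id.clone();
    let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));

    let on_message_callback = move |index: u32, stream_token: StreamResponse| {
        // Timestamp for the `created` field of the chunk.
        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
            .as_secs();

        Event::default()
            .json_data(ChatCompletionChunk::new(
                model_id.clone(),
                system_fingerprint.clone(),
                stream_token.token.text,
                current_time,
                index, // per-stream counter maintained in generate_stream_internal
                None,  // logprobs
                None,  // finish_reason
            ))
            .unwrap_or_else(|_| Event::default()) // assumed fallback on serialization failure
    };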