Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: adds index, model id, system fingerprint and updates do_sample param
commit f82ff3f64a (parent ddf7412a6b)
@@ -158,7 +158,7 @@ fn default_parameters() -> GenerateParameters {
         top_k: None,
         top_p: None,
         typical_p: None,
-        do_sample: false,
+        do_sample: true,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
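This hunk flips the router-side default for `do_sample` from greedy decoding to sampling. A minimal, self-contained sketch of the serde-default behaviour it changes, using a simplified stand-in for `GenerateParameters` (in the real code the default is supplied struct-wide via `default_parameters`):

```rust
use serde::Deserialize;

// Simplified stand-in for the router's GenerateParameters.
#[derive(Deserialize, Debug)]
struct Params {
    #[serde(default)]
    top_p: Option<f32>,
    #[serde(default = "default_do_sample")]
    do_sample: bool,
}

// After this commit the default is `true`: requests that omit the field
// get sampling instead of greedy decoding.
fn default_do_sample() -> bool {
    true
}

fn main() {
    let p: Params = serde_json::from_str(r#"{"top_p": 0.9}"#).unwrap();
    assert!(p.do_sample);
}
```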
@@ -253,21 +253,29 @@ pub(crate) struct ChatCompletionDelta {
 }
 
 impl ChatCompletionChunk {
-    pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        delta: String,
+        created: u64,
+        index: u32,
+        logprobs: Option<Vec<f32>>,
+        finish_reason: Option<String>,
+    ) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
             created,
-            model: "".to_string(),
-            system_fingerprint: "".to_string(),
+            model,
+            system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index,
                 delta: ChatCompletionDelta {
                     role: "assistant".to_string(),
                     content: delta,
                 },
-                logprobs: None,
-                finish_reason: None,
+                logprobs,
+                finish_reason,
             }],
         }
     }
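The call site for the widened constructor appears in the chat streaming hunk at the bottom of this diff; pulled out here as a sketch (not compilable on its own, since the types and variables live in the router crate):

```rust
// All seven arguments, in the order the streaming callback passes them.
let chunk = ChatCompletionChunk::new(
    model_id.clone(),           // model: the served model id from `Info`
    system_fingerprint.clone(), // "{version}-{docker_label}"
    stream_token.token.text,    // delta: text of the newly generated token
    current_time,               // created: unix seconds
    index,                      // position of this chunk in the stream
    None,                       // logprobs: not populated here
    None,                       // finish_reason: None until the final chunk
);
```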
@@ -21,7 +21,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
-use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -339,7 +338,7 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
-    let on_message_callback = |stream_token: StreamResponse| {
+    let on_message_callback = |_: u32, stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
@@ -353,7 +352,7 @@ async fn generate_stream(
 async fn generate_stream_internal(
     infer: Infer,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(StreamResponse) -> Event,
+    on_message_callback: impl Fn(u32, StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
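Both stream entry points now agree on the `impl Fn(u32, StreamResponse) -> Event` shape: plain `generate_stream` ignores the index, while `chat_completions` uses it. A self-contained sketch of the two closure styles, with stand-ins for the router/axum types:

```rust
// Stand-ins for the router's StreamResponse and axum's SSE Event.
struct StreamResponse {
    text: String,
}

#[derive(Debug)]
struct Event(String);

fn emit(cb: impl Fn(u32, StreamResponse) -> Event) {
    let event = cb(1, StreamResponse { text: "hi".into() });
    println!("{event:?}");
}

fn main() {
    // generate_stream style: the index is ignored.
    emit(|_: u32, st| Event(st.text));
    // chat_completions style: the index tags each chunk.
    emit(|index, st| Event(format!("{index}: {}", st.text)));
}
```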
@@ -397,8 +396,10 @@ async fn generate_stream_internal(
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
             Ok((_permit, mut response_stream)) => {
+                let mut index = 0;
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
+                    index += 1;
                     match response {
                         Ok(response) => {
                             match response {
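Because the counter is bumped at the top of the loop, before the first callback fires, emitted indices start at 1 rather than 0; the pattern in isolation:

```rust
fn main() {
    let mut index = 0;
    for token in ["Hello", " world", "!"] {
        index += 1; // incremented before use: first emitted index is 1
        println!("index={index} token={token}");
    }
}
```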
@@ -418,8 +419,7 @@ async fn generate_stream_internal(
                                         generated_text: None,
                                         details: None,
                                     };
-
-                                    let event = on_message_callback(stream_token);
+                                    let event = on_message_callback(index, stream_token);
                                     yield Ok(event);
                                 }
                                 // Yield event for last token and compute timings
@@ -483,7 +483,7 @@ async fn generate_stream_internal(
                                     };
 
 
-                                    let event = on_message_callback(stream_token);
+                                    let event = on_message_callback(index, stream_token);
                                     yield Ok(event);
                                     break;
                                 }
@@ -550,6 +550,7 @@ async fn generate_stream_internal(
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
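The new `Extension(info)` argument is served by axum's extension layer, which the router installs at startup. A sketch of the general pattern (simplified `Info`, not the repository's exact setup code):

```rust
use axum::{extract::Extension, routing::post, Router};

// Simplified stand-in for the router's Info struct.
#[derive(Clone)]
struct Info {
    model_id: String,
    version: String,
    docker_label: Option<&'static str>,
}

async fn chat_completions(Extension(info): Extension<Info>) -> String {
    info.model_id
}

fn app(info: Info) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        // Handlers extract Extension<Info> from this layer.
        .layer(Extension(info))
}

fn main() {
    let _app = app(Info {
        model_id: "my-model".into(),
        version: "1.3.4".into(),
        docker_label: None,
    });
}
```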
@@ -605,9 +606,14 @@ async fn chat_completions(
 
     // switch on stream
     if stream {
-        let stream_count = AtomicU32::new(0);
+        let model_id = info.model_id.clone();
+        let system_fingerprint = format!(
+            "{}-{}",
+            info.version,
+            info.docker_label.unwrap_or("native")
+        );
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = move |stream_token: StreamResponse| {
+        let on_message_callback = move |index: u32, stream_token: StreamResponse| {
             let event = Event::default();
 
             let current_time = std::time::SystemTime::now()
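The fingerprint is simply the server version joined to the docker label, falling back to "native" when no label is set (values here are hypothetical):

```rust
fn main() {
    let version = "1.3.4";                 // Info::version
    let docker_label: Option<&str> = None; // unset outside the docker image
    let system_fingerprint =
        format!("{}-{}", version, docker_label.unwrap_or("native"));
    assert_eq!(system_fingerprint, "1.3.4-native");
}
```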
@@ -615,15 +621,15 @@ async fn chat_completions(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();
 
-            // increment the stream count
-            stream_count.fetch_add(1, Ordering::SeqCst);
-            let current_stream_count = stream_count.load(Ordering::SeqCst);
-
             event
                 .json_data(ChatCompletionChunk::new(
+                    model_id.clone(),
+                    system_fingerprint.clone(),
                     stream_token.token.text,
                     current_time,
-                    current_stream_count,
+                    index,
+                    None,
+                    None,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
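Taken together, each SSE data line now carries the model id, the fingerprint, and a monotonically increasing choice index. An illustrative chunk with hypothetical values, built with `serde_json` to mirror the `ChatCompletionChunk` fields above:

```rust
use serde_json::json;

fn main() {
    let chunk = json!({
        "id": "",
        "object": "text_completion",
        "created": 1_700_000_000u64,
        "model": "my-model",
        "system_fingerprint": "1.3.4-native",
        "choices": [{
            "index": 1,
            "delta": { "role": "assistant", "content": " world" },
            "logprobs": null,
            "finish_reason": null
        }]
    });
    println!("data: {chunk}");
}
```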