fix: adds index, model id, system fingerprint and updates do_sample param

drbh 2024-01-09 11:54:20 -05:00
parent ddf7412a6b
commit f82ff3f64a
2 changed files with 33 additions and 19 deletions

@@ -158,7 +158,7 @@ fn default_parameters() -> GenerateParameters {
         top_k: None,
         top_p: None,
         typical_p: None,
-        do_sample: false,
+        do_sample: true,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
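
Note: `default_parameters` supplies the serde default for a request's `parameters`, so flipping this flag changes what requests that omit `do_sample` get. A minimal sketch of the mechanism, using a hypothetical per-field serde default (the crate itself defaults the whole struct via `default_parameters()`):

```rust
use serde::Deserialize; // assumes the serde (derive) and serde_json crates

// Illustrative per-field default; the actual code returns a full
// GenerateParameters from default_parameters().
fn default_do_sample() -> bool {
    true // this commit flips the default from `false`
}

#[derive(Debug, Deserialize)]
struct GenerateParameters {
    #[serde(default = "default_do_sample")]
    do_sample: bool,
}

fn main() {
    // A request body that omits `do_sample` now defaults to sampling.
    let params: GenerateParameters = serde_json::from_str("{}").unwrap();
    assert!(params.do_sample);
}
```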
@@ -253,21 +253,29 @@ pub(crate) struct ChatCompletionDelta {
 }
 
 impl ChatCompletionChunk {
-    pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        delta: String,
+        created: u64,
+        index: u32,
+        logprobs: Option<Vec<f32>>,
+        finish_reason: Option<String>,
+    ) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
             created,
-            model: "".to_string(),
-            system_fingerprint: "".to_string(),
+            model,
+            system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index,
                 delta: ChatCompletionDelta {
                     role: "assistant".to_string(),
                     content: delta,
                 },
-                logprobs: None,
-                finish_reason: None,
+                logprobs,
+                finish_reason,
             }],
         }
     }
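
The constructor now takes the model id, system fingerprint, logprobs, and finish reason from the caller instead of hardcoding them. A sketch of the new call shape, with abbreviated stand-in structs and hypothetical argument values:

```rust
#[allow(dead_code)]
struct ChatCompletionDelta {
    role: String,
    content: String,
}

#[allow(dead_code)]
struct ChatCompletionChoice {
    index: u32,
    delta: ChatCompletionDelta,
    logprobs: Option<Vec<f32>>,
    finish_reason: Option<String>,
}

#[allow(dead_code)]
struct ChatCompletionChunk {
    id: String,
    object: String,
    created: u64,
    model: String,
    system_fingerprint: String,
    choices: Vec<ChatCompletionChoice>,
}

impl ChatCompletionChunk {
    // Mirrors the widened constructor in the hunk above.
    fn new(
        model: String,
        system_fingerprint: String,
        delta: String,
        created: u64,
        index: u32,
        logprobs: Option<Vec<f32>>,
        finish_reason: Option<String>,
    ) -> Self {
        Self {
            id: "".to_string(),
            object: "text_completion".to_string(),
            created,
            model,
            system_fingerprint,
            choices: vec![ChatCompletionChoice {
                index,
                delta: ChatCompletionDelta {
                    role: "assistant".to_string(),
                    content: delta,
                },
                logprobs,
                finish_reason,
            }],
        }
    }
}

fn main() {
    // Hypothetical values; the router takes them from Info and the token stream.
    let chunk = ChatCompletionChunk::new(
        "my-model".to_string(),
        "1.0-native".to_string(),
        "Hello".to_string(),
        1_704_800_000, // unix seconds
        1,
        None,
        None,
    );
    println!("choice index: {}", chunk.choices[0].index);
}
```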

@@ -21,7 +21,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
-use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -339,7 +338,7 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
-    let on_message_callback = |stream_token: StreamResponse| {
+    let on_message_callback = |_: u32, stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
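
Passing the index into the callback (ignored here as `_`) is what lets `chat_completions` further down number chunks without the shared `AtomicU32` counter whose import is dropped above. A minimal runnable sketch of the pattern, with a stand-in `StreamResponse`:

```rust
// Stand-in for the router's StreamResponse; the real struct has more fields.
struct StreamResponse {
    text: String,
}

fn main() {
    // The callback receives the token index as an explicit argument...
    let on_message_callback = |index: u32, stream_token: StreamResponse| {
        format!("chunk {index}: {}", stream_token.text)
    };

    // ...and the generation loop owns a plain counter, mirroring the
    // `let mut index = 0; ... index += 1;` change below.
    let mut index = 0u32;
    for text in ["Hel", "lo", "!"] {
        index += 1;
        let event = on_message_callback(index, StreamResponse { text: text.to_string() });
        println!("{event}");
    }
}
```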
@@ -353,7 +352,7 @@ async fn generate_stream(
 async fn generate_stream_internal(
     infer: Infer,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(StreamResponse) -> Event,
+    on_message_callback: impl Fn(u32, StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
@@ -397,8 +396,10 @@ async fn generate_stream_internal(
     match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
         // Keep permit as long as generate_stream lives
         Ok((_permit, mut response_stream)) => {
+            let mut index = 0;
             // Server-Sent Event stream
             while let Some(response) = response_stream.next().await {
+                index += 1;
                 match response {
                     Ok(response) => {
                         match response {
@@ -418,8 +419,7 @@ async fn generate_stream_internal(
                                     generated_text: None,
                                     details: None,
                                 };
-
-                                let event = on_message_callback(stream_token);
+                                let event = on_message_callback(index, stream_token);
                                 yield Ok(event);
                             }
                             // Yield event for last token and compute timings
@@ -483,7 +483,7 @@ async fn generate_stream_internal(
                                 };

-                                let event = on_message_callback(stream_token);
+                                let event = on_message_callback(index, stream_token);
                                 yield Ok(event);
                                 break;
                             }
@@ -550,6 +550,7 @@
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
@@ -605,9 +606,14 @@ async fn chat_completions(
     // switch on stream
     if stream {
-        let stream_count = AtomicU32::new(0);
+        let model_id = info.model_id.clone();
+        let system_fingerprint = format!(
+            "{}-{}",
+            info.version,
+            info.docker_label.unwrap_or("native")
+        );

         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = move |stream_token: StreamResponse| {
+        let on_message_callback = move |index: u32, stream_token: StreamResponse| {
             let event = Event::default();

             let current_time = std::time::SystemTime::now()
@@ -615,15 +621,15 @@ async fn chat_completions(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();

-            // increment the stream count
-            stream_count.fetch_add(1, Ordering::SeqCst);
-            let current_stream_count = stream_count.load(Ordering::SeqCst);
-
             event
                 .json_data(ChatCompletionChunk::new(
+                    model_id.clone(),
+                    system_fingerprint.clone(),
                     stream_token.token.text,
                     current_time,
-                    current_stream_count,
+                    index,
+                    None,
+                    None,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
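
For reference, the system fingerprint built above is `<version>-<docker_label>`, falling back to `native` when no docker label is present. A tiny sketch with hypothetical values:

```rust
fn main() {
    // Hypothetical Info values; the router reads these from its build metadata.
    let version = "1.3.4";
    let docker_label: Option<&str> = None;

    // Same composition as in the hunk above.
    let system_fingerprint = format!("{}-{}", version, docker_label.unwrap_or("native"));
    assert_eq!(system_fingerprint, "1.3.4-native");
}
```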