fix: adds index, model id, system fingerprint and updates do_sample param

Author: drbh
Date: 2024-01-09 11:54:20 -05:00
Parent: ddf7412a6b
Commit: f82ff3f64a
2 changed files with 33 additions and 19 deletions

View File

@@ -158,7 +158,7 @@ fn default_parameters() -> GenerateParameters {
         top_k: None,
         top_p: None,
         typical_p: None,
-        do_sample: false,
+        do_sample: true,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
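
Review note: flipping this default means a request that omits do_sample now samples instead of decoding greedily; greedy output has to be requested explicitly. A minimal standalone sketch of the effect, using a hypothetical stand-in struct rather than the router's full GenerateParameters:

#[derive(Debug)]
struct Params {
    do_sample: bool,
}

impl Default for Params {
    // mirrors the patched default above: sampling is on unless callers opt out
    fn default() -> Self {
        Self { do_sample: true }
    }
}

fn main() {
    assert!(Params::default().do_sample);
    let greedy = Params { do_sample: false }; // greedy decoding is now opt-in
    assert!(!greedy.do_sample);
}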
@@ -253,21 +253,29 @@ pub(crate) struct ChatCompletionDelta {
 }
 
 impl ChatCompletionChunk {
-    pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        delta: String,
+        created: u64,
+        index: u32,
+        logprobs: Option<Vec<f32>>,
+        finish_reason: Option<String>,
+    ) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
             created,
-            model: "".to_string(),
-            system_fingerprint: "".to_string(),
+            model,
+            system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index,
                 delta: ChatCompletionDelta {
                     role: "assistant".to_string(),
                     content: delta,
                 },
-                logprobs: None,
-                finish_reason: None,
+                logprobs,
+                finish_reason,
             }],
         }
     }
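
For reviewers, a hedged sketch of a call site for the widened constructor; the literal values are illustrative only, and in the real server path the model id and fingerprint come from the Info extension added below:

let chunk = ChatCompletionChunk::new(
    "my-model".to_string(),     // model (illustrative)
    "1.3.4-native".to_string(), // system_fingerprint (illustrative)
    "Hello".to_string(),        // delta: streamed token text
    1704812060,                 // created: unix seconds
    1,                          // index of this chunk in the stream
    None,                       // logprobs: not populated yet
    None,                       // finish_reason: None until the final chunk
);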

View File

@@ -21,7 +21,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
-use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -339,7 +338,7 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
-    let on_message_callback = |stream_token: StreamResponse| {
+    let on_message_callback = |_: u32, stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
@@ -353,7 +352,7 @@
 async fn generate_stream_internal(
     infer: Infer,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(StreamResponse) -> Event,
+    on_message_callback: impl Fn(u32, StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
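
Together with the closure change above, the callback contract becomes Fn(u32, StreamResponse) -> Event. A standalone sketch of that shape, with hypothetical Token and Event stand-ins for StreamResponse and the SSE event type; note the pre-increment (introduced in the next hunk), which means the first token is delivered with index 1:

struct Token(String);
struct Event(String);

fn stream_tokens(tokens: Vec<Token>, on_message: impl Fn(u32, Token) -> Event) -> Vec<Event> {
    let mut index = 0u32;
    tokens
        .into_iter()
        .map(|token| {
            index += 1; // incremented before the first call, so indices start at 1
            on_message(index, token)
        })
        .collect()
}

fn main() {
    // the plain /generate_stream path simply ignores the index, as the router does
    let events = stream_tokens(vec![Token("Hi".into())], |_: u32, t| Event(t.0));
    assert_eq!(events.len(), 1);
}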
@@ -397,8 +396,10 @@
     match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
         // Keep permit as long as generate_stream lives
         Ok((_permit, mut response_stream)) => {
+            let mut index = 0;
             // Server-Sent Event stream
             while let Some(response) = response_stream.next().await {
+                index += 1;
                 match response {
                     Ok(response) => {
                         match response {
@@ -418,8 +419,7 @@
                                 generated_text: None,
                                 details: None,
                             };
-                            let event = on_message_callback(stream_token);
-
+                            let event = on_message_callback(index, stream_token);
                             yield Ok(event);
                         }
                         // Yield event for last token and compute timings
@@ -483,7 +483,7 @@
                             };
-                            let event = on_message_callback(stream_token);
+                            let event = on_message_callback(index, stream_token);
                             yield Ok(event);
                             break;
                         }
@@ -550,6 +550,7 @@
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
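
Destructuring Extension(info) only works if an Info value was attached as a layer when the router was built. A hedged, minimal axum sketch of that wiring, with a hypothetical trimmed-down Info (TGI's real struct carries more fields):

use axum::{routing::post, Extension, Router};

#[derive(Clone)]
struct Info {
    model_id: String,
    version: &'static str,
    docker_label: Option<&'static str>,
}

async fn chat_completions(Extension(info): Extension<Info>) -> String {
    // the handler reads model metadata straight from shared state
    info.model_id
}

fn app(info: Info) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        // without this layer, requests to the handler fail at runtime
        .layer(Extension(info))
}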
@@ -605,9 +606,14 @@
     // switch on stream
     if stream {
-        let stream_count = AtomicU32::new(0);
+        let model_id = info.model_id.clone();
+        let system_fingerprint = format!(
+            "{}-{}",
+            info.version,
+            info.docker_label.unwrap_or("native")
+        );
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = move |stream_token: StreamResponse| {
+        let on_message_callback = move |index: u32, stream_token: StreamResponse| {
             let event = Event::default();
 
             let current_time = std::time::SystemTime::now()
@@ -615,15 +621,15 @@
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();
 
-            // increment the stream count
-            stream_count.fetch_add(1, Ordering::SeqCst);
-            let current_stream_count = stream_count.load(Ordering::SeqCst);
             event
                 .json_data(ChatCompletionChunk::new(
+                    model_id.clone(),
+                    system_fingerprint.clone(),
                     stream_token.token.text,
                     current_time,
-                    current_stream_count,
+                    index,
+                    None,
+                    None,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
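
A quick check of the fingerprint format built above, assuming info.version is a plain version string and docker_label is unset on non-Docker builds:

fn main() {
    let version = "1.3.4";                 // stand-in for info.version
    let docker_label: Option<&str> = None; // unset outside Docker builds
    let fingerprint = format!("{}-{}", version, docker_label.unwrap_or("native"));
    assert_eq!(fingerprint, "1.3.4-native");
}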