feat: support repetition_penalty and improve non stream response

drbh 2024-01-09 13:31:15 -05:00
parent fba1953eb6
commit 8c4ab53780
2 changed files with 22 additions and 13 deletions
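
A minimal sketch of a request body that exercises both changes: stream set to false takes the improved non-stream path, and frequency_penalty feeds the new repetition_penalty mapping. Field names follow the OpenAI-style chat API this router mirrors; the model name, values, and the use of serde_json are illustrative assumptions, not taken from this diff.

// Sketch only: builds the JSON body a client might send; which fields the
// router's ChatRequest actually accepts at this commit is an assumption.
fn main() {
    let body = serde_json::json!({
        "model": "tgi",                                             // placeholder model name
        "messages": [{ "role": "user", "content": "Say hello" }],
        "frequency_penalty": 0.5,                                   // rescaled to repetition_penalty 2.5
        "max_tokens": 50,                                           // otherwise defaults to 100
        "stream": false                                             // take the non-stream response path
    });
    println!("{body}");
}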

File 1 of 2:

@@ -224,13 +224,19 @@ pub(crate) struct Usage {
 }
 
 impl ChatCompletion {
-    pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        ouput: String,
+        created: u64,
+        details: Details,
+    ) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
             created,
-            model: "".to_string(),
-            system_fingerprint: "".to_string(),
+            model,
+            system_fingerprint,
             choices: vec![ChatCompletionComplete {
                 index: 0,
                 message: Message {

File 2 of 2:

@@ -558,12 +558,12 @@ async fn chat_completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
 
-    // extract the values we need for the chat request
     let stream = req.stream;
-    let max_new_tokens = match req.max_tokens {
-        Some(max_new_tokens) => Some(max_new_tokens),
-        None => Some(100),
-    };
+    let max_new_tokens = req.max_tokens.or(Some(100));
+    let repetition_penalty = req
+        .frequency_penalty
+        // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
+        .map(|x| x + 2.0);
 
     // apply chat template to flatten the request into a single input
     let inputs = match infer.apply_chat_template(req) {
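
A standalone sketch of the two value mappings introduced in the hunk above, runnable on its own; the helper function name is invented for illustration.

// Mirrors the request handling above: max_tokens falls back to 100 when unset,
// and an OpenAI-style frequency_penalty in (-2.0, 2.0) is shifted by +2.0 into
// the (0.0, 4.0) range used for repetition_penalty.
fn to_repetition_penalty(frequency_penalty: Option<f32>) -> Option<f32> {
    frequency_penalty.map(|x| x + 2.0)
}

fn main() {
    let max_tokens: Option<u32> = None;
    assert_eq!(max_tokens.or(Some(100)), Some(100)); // default applied when the client omits it

    assert_eq!(to_repetition_penalty(None), None);            // unset stays unset
    assert_eq!(to_repetition_penalty(Some(-2.0)), Some(0.0)); // lower bound of the rescaled range
    assert_eq!(to_repetition_penalty(Some(0.5)), Some(2.5));
    assert_eq!(to_repetition_penalty(Some(2.0)), Some(4.0));  // upper bound of the rescaled range
}
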
@@ -587,11 +587,11 @@ async fn chat_completions(
         parameters: GenerateParameters {
             best_of: None,
             temperature: None,
-            repetition_penalty: None,
+            repetition_penalty,
             top_k: None,
             top_p: None,
             typical_p: None,
-            do_sample: false,
+            do_sample: true,
             max_new_tokens,
             return_full_text: None,
             stop: Vec::new(),
@@ -604,11 +604,12 @@ async fn chat_completions(
         },
     };
 
+    // static values that will be returned in all cases
+    let model_id = info.model_id.clone();
+    let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+
     // switch on stream
     if stream {
-        let model_id = info.model_id.clone();
-        let system_fingerprint =
-            format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
         // pass this callback to the stream generation and build the required event structure
         let on_message_callback = move |stream_token: StreamResponse| {
             let event = Event::default();
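
A standalone sketch of the fingerprint string assembled in the hunk above; the version and docker label values are invented for illustration.

// Mirrors the system_fingerprint construction: "<version>-<docker_label>",
// with "native" substituted when no docker label is present.
fn main() {
    let version = "1.3.4";                  // stands in for info.version
    let docker_label: Option<&str> = None;  // stands in for info.docker_label
    let system_fingerprint = format!("{}-{}", version, docker_label.unwrap_or("native"));
    assert_eq!(system_fingerprint, "1.3.4-native");
}
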
@@ -650,6 +651,8 @@ async fn chat_completions(
         // build the complete response object with the full text
         let response = ChatCompletion::new(
             generation.generated_text,
+            model_id,
+            system_fingerprint,
             current_time,
             generation.details.unwrap(),
         );
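
For reference, a sketch of the non-stream response fields this commit touches; only fields visible in this diff are shown, the concrete values are placeholders, and serde_json is assumed for illustration.

// Sketch only: model and system_fingerprint were previously serialized as
// empty strings and are now populated from the server info.
fn main() {
    let response = serde_json::json!({
        "id": "",
        "object": "text_completion",
        "created": 1704824000,                  // placeholder unix timestamp
        "model": "my-model",                    // was "" before this change
        "system_fingerprint": "1.3.4-native",   // was "" before this change
        "choices": [{ "index": 0 }]             // message and remaining fields omitted
    });
    println!("{response}");
}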