mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: support repetition_penalty and improve non stream response
This commit is contained in:
parent
fba1953eb6
commit
8c4ab53780
@ -224,13 +224,19 @@ pub(crate) struct Usage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ChatCompletion {
|
impl ChatCompletion {
|
||||||
pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self {
|
pub(crate) fn new(
|
||||||
|
model: String,
|
||||||
|
system_fingerprint: String,
|
||||||
|
ouput: String,
|
||||||
|
created: u64,
|
||||||
|
details: Details,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: "".to_string(),
|
id: "".to_string(),
|
||||||
object: "text_completion".to_string(),
|
object: "text_completion".to_string(),
|
||||||
created,
|
created,
|
||||||
model: "".to_string(),
|
model,
|
||||||
system_fingerprint: "".to_string(),
|
system_fingerprint,
|
||||||
choices: vec![ChatCompletionComplete {
|
choices: vec![ChatCompletionComplete {
|
||||||
index: 0,
|
index: 0,
|
||||||
message: Message {
|
message: Message {
|
||||||
|
@ -558,12 +558,12 @@ async fn chat_completions(
|
|||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
// extract the values we need for the chat request
|
|
||||||
let stream = req.stream;
|
let stream = req.stream;
|
||||||
let max_new_tokens = match req.max_tokens {
|
let max_new_tokens = req.max_tokens.or(Some(100));
|
||||||
Some(max_new_tokens) => Some(max_new_tokens),
|
let repetition_penalty = req
|
||||||
None => Some(100),
|
.frequency_penalty
|
||||||
};
|
// rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
|
||||||
|
.map(|x| x + 2.0);
|
||||||
|
|
||||||
// apply chat template to flatten the request into a single input
|
// apply chat template to flatten the request into a single input
|
||||||
let inputs = match infer.apply_chat_template(req) {
|
let inputs = match infer.apply_chat_template(req) {
|
||||||
@ -587,11 +587,11 @@ async fn chat_completions(
|
|||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
best_of: None,
|
best_of: None,
|
||||||
temperature: None,
|
temperature: None,
|
||||||
repetition_penalty: None,
|
repetition_penalty,
|
||||||
top_k: None,
|
top_k: None,
|
||||||
top_p: None,
|
top_p: None,
|
||||||
typical_p: None,
|
typical_p: None,
|
||||||
do_sample: false,
|
do_sample: true,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
return_full_text: None,
|
return_full_text: None,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
@ -604,11 +604,12 @@ async fn chat_completions(
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// static values that will be returned in all cases
|
||||||
|
let model_id = info.model_id.clone();
|
||||||
|
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
|
|
||||||
// switch on stream
|
// switch on stream
|
||||||
if stream {
|
if stream {
|
||||||
let model_id = info.model_id.clone();
|
|
||||||
let system_fingerprint =
|
|
||||||
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
|
||||||
// pass this callback to the stream generation and build the required event structure
|
// pass this callback to the stream generation and build the required event structure
|
||||||
let on_message_callback = move |stream_token: StreamResponse| {
|
let on_message_callback = move |stream_token: StreamResponse| {
|
||||||
let event = Event::default();
|
let event = Event::default();
|
||||||
@ -650,6 +651,8 @@ async fn chat_completions(
|
|||||||
// build the complete response object with the full text
|
// build the complete response object with the full text
|
||||||
let response = ChatCompletion::new(
|
let response = ChatCompletion::new(
|
||||||
generation.generated_text,
|
generation.generated_text,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
current_time,
|
current_time,
|
||||||
generation.details.unwrap(),
|
generation.details.unwrap(),
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user