Mirror of https://github.com/huggingface/text-generation-inference.git

feat: support logprobs in streaming and non streaming chat

parent 65c913b55d
commit 9a79c2f867
@@ -213,7 +213,7 @@ pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
     pub logprobs: Option<Vec<f32>>,
-    pub finish_reason: Option<String>,
+    pub finish_reason: String,
 }

 #[derive(Clone, Deserialize, Serialize)]
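
With finish_reason now a plain String, every non-streaming choice is guaranteed to carry a finish reason. Below is a minimal sketch of the resulting payload shape, assuming serde and serde_json as dependencies; Message and ChatCompletionComplete are trimmed-down stand-ins for the crate's types, and the field values are made up for illustration.

use serde::Serialize;

// Trimmed-down stand-ins, just to show the shape of a serialized
// choice after this change.
#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(Serialize)]
struct ChatCompletionComplete {
    index: u32,
    message: Message,
    logprobs: Option<Vec<f32>>,
    finish_reason: String,
}

fn main() {
    let choice = ChatCompletionComplete {
        index: 0,
        message: Message {
            role: "assistant".into(),
            content: "Hello!".into(),
        },
        logprobs: Some(vec![-0.1, -0.9]),
        finish_reason: "stop".into(),
    };
    println!("{}", serde_json::to_string_pretty(&choice).unwrap());
}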
@@ -227,24 +227,26 @@ impl ChatCompletion {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        ouput: String,
+        output: String,
         created: u64,
         details: Details,
+        return_logprobs: bool,
     ) -> Self {
         Self {
-            id: "".to_string(),
-            object: "text_completion".to_string(),
+            id: String::new(),
+            object: "text_completion".into(),
             created,
             model,
             system_fingerprint,
             choices: vec![ChatCompletionComplete {
                 index: 0,
                 message: Message {
-                    role: "assistant".to_string(),
-                    content: ouput,
+                    role: "assistant".into(),
+                    content: output,
                 },
-                logprobs: None,
-                finish_reason: details.finish_reason.to_string().into(),
+                logprobs: return_logprobs
+                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
                 prompt_tokens: details.prompt_token_count,
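
The new return_logprobs parameter drives the logprobs field through bool::then, which evaluates the closure and wraps its result in Some only when the flag is true. A self-contained sketch of that idiom, with Token as a simplified stand-in for the crate's token type:

// Token is a simplified stand-in; only the logprob field matters here.
struct Token {
    logprob: f32,
}

fn collect_logprobs(return_logprobs: bool, tokens: &[Token]) -> Option<Vec<f32>> {
    // bool::then returns Some(closure result) when true, None when false.
    return_logprobs.then(|| tokens.iter().map(|t| t.logprob).collect())
}

fn main() {
    let tokens = vec![Token { logprob: -0.25 }, Token { logprob: -1.1 }];
    assert_eq!(collect_logprobs(true, &tokens), Some(vec![-0.25, -1.1]));
    assert_eq!(collect_logprobs(false, &tokens), None);
}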
@@ -269,7 +271,7 @@ pub(crate) struct ChatCompletionChunk {
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
-    pub logprobs: Option<Vec<f32>>,
+    pub logprobs: Option<f32>,
     pub finish_reason: Option<String>,
 }

@@ -286,7 +288,7 @@ impl ChatCompletionChunk {
         delta: String,
         created: u64,
         index: u32,
-        logprobs: Option<Vec<f32>>,
+        logprobs: Option<f32>,
         finish_reason: Option<String>,
     ) -> Self {
         Self {
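
Note the asymmetry between the two paths: a streaming chunk now carries at most one logprob for its single token (Option<f32>), while the non-streaming response aggregates all token logprobs (Option<Vec<f32>>). A purely illustrative sketch of how a client might reassemble streamed values into the non-streaming shape:

fn main() {
    // Each streamed chunk contributes at most one logprob.
    let chunk_logprobs: Vec<Option<f32>> = vec![Some(-0.12), Some(-1.5), Some(-0.03)];

    // Collecting Option<f32> into Option<Vec<f32>> yields Some only if
    // every chunk carried a value, mirroring the non-streaming field.
    let aggregated: Option<Vec<f32>> = chunk_logprobs.into_iter().collect();
    assert_eq!(aggregated, Some(vec![-0.12, -1.5, -0.03]));
}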
@@ -340,12 +342,10 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     pub logit_bias: Option<Vec<f32>>,

-    /// UNUSED
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
-    /// output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview
-    /// model.
+    /// output token returned in the content of message.
     #[serde(default)]
-    pub logprobs: Option<u32>,
+    pub logprobs: Option<bool>,

     /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
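
With the field typed as Option<bool> and marked #[serde(default)], a request that omits logprobs deserializes to None, which the handler below maps to false. A minimal sketch, assuming serde and serde_json as dependencies and using a trimmed-down stand-in for ChatRequest:

use serde::Deserialize;

// Trimmed-down stand-in for ChatRequest, keeping only the logprobs field.
#[derive(Deserialize)]
struct MiniChatRequest {
    #[serde(default)]
    logprobs: Option<bool>,
}

fn main() {
    // An omitted key falls back to the default, i.e. None.
    let absent: MiniChatRequest = serde_json::from_str("{}").unwrap();
    assert_eq!(absent.logprobs, None);

    // An explicit boolean round-trips as Some(true).
    let present: MiniChatRequest = serde_json::from_str(r#"{"logprobs": true}"#).unwrap();
    assert_eq!(present.logprobs, Some(true));
}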
@@ -564,6 +564,7 @@ async fn chat_completions(
         .frequency_penalty
         // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
+    let logprobs = req.logprobs.unwrap_or(false);

     // apply chat template to flatten the request into a single input
     let inputs = match infer.apply_chat_template(req) {
@@ -626,13 +627,16 @@ async fn chat_completions(
                 stream_token.token.text,
                 current_time,
                 stream_token.index,
-                None,
+                logprobs.then(|| stream_token.token.logprob),
                 stream_token.details.map(|d| d.finish_reason.to_string()),
             ))
-            .unwrap_or_else(|e| {
-                println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-                Event::default()
-            })
+            .map_or_else(
+                |e| {
+                    println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                    Event::default()
+                },
+                |data| data,
+            )
         };

         let (headers, response_stream) =
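
Result::map_or_else takes the error handler first and the success handler second, so the rewrite keeps the old logging fallback while passing a successfully serialized event through unchanged. A minimal sketch of the pattern, with String standing in for the Event type:

// String stands in for the Event type; the control flow is the point.
fn to_event(serialized: Result<String, &str>) -> String {
    serialized.map_or_else(
        // First closure handles Err: log and fall back to a default.
        |e| {
            println!("Failed to serialize ChatCompletionChunk: {:?}", e);
            String::from("default event")
        },
        // Second closure handles Ok: pass the data through.
        |data| data,
    )
}

fn main() {
    assert_eq!(to_event(Ok("data: {}".to_string())), "data: {}");
    assert_eq!(to_event(Err("boom")), "default event");
}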
@@ -655,6 +659,7 @@ async fn chat_completions(
             system_fingerprint,
             current_time,
             generation.details.unwrap(),
+            logprobs,
         );

         // wrap generation inside a Vec to match api-inference