mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Change deltas too.
This commit is contained in:
parent
2a87dd7274
commit
961a873305
@ -11,6 +11,7 @@ use queue::{Entry, Queue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::warn;
|
||||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
@ -539,9 +540,25 @@ impl ChatCompletion {
|
||||
content: vec![MessageChunk::Text(Text { text: output })],
|
||||
name: None,
|
||||
}),
|
||||
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { tool_calls }),
|
||||
_ => {
|
||||
todo!("Implement error for invalid tool vs chat");
|
||||
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls,
|
||||
}),
|
||||
(Some(output), Some(_)) => {
|
||||
warn!("Received both chat and tool call");
|
||||
OutputMessage::ChatMessage(Message {
|
||||
role: "assistant".into(),
|
||||
content: vec![MessageChunk::Text(Text { text: output })],
|
||||
name: None,
|
||||
})
|
||||
}
|
||||
(None, None) => {
|
||||
warn!("Didn't receive an answer");
|
||||
OutputMessage::ChatMessage(Message {
|
||||
role: "assistant".into(),
|
||||
content: vec![],
|
||||
name: None,
|
||||
})
|
||||
}
|
||||
};
|
||||
Self {
|
||||
@ -574,6 +591,7 @@ pub(crate) struct CompletionCompleteChunk {
|
||||
pub model: String,
|
||||
pub system_fingerprint: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct ChatCompletionChunk {
|
||||
pub id: String,
|
||||
@ -594,21 +612,20 @@ pub(crate) struct ChatCompletionChoice {
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct ChatCompletionDelta {
|
||||
#[schema(example = "user")]
|
||||
// TODO Modify this to a true enum.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub content: Option<String>,
|
||||
// default to None
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<DeltaToolCall>,
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
pub struct ToolCallDelta {
|
||||
#[schema(example = "assistant")]
|
||||
role: String,
|
||||
tool_calls: DeltaToolCall,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
enum ChatCompletionDelta {
|
||||
Chat(TextMessage),
|
||||
Tool(ToolCallDelta),
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||
pub(crate) struct DeltaToolCall {
|
||||
pub index: u32,
|
||||
pub id: String,
|
||||
@ -616,7 +633,7 @@ pub(crate) struct DeltaToolCall {
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||
pub(crate) struct Function {
|
||||
pub name: Option<String>,
|
||||
pub arguments: String,
|
||||
@ -634,15 +651,13 @@ impl ChatCompletionChunk {
|
||||
finish_reason: Option<String>,
|
||||
) -> Self {
|
||||
let delta = match (delta, tool_calls) {
|
||||
(Some(delta), _) => ChatCompletionDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: Some(delta),
|
||||
tool_calls: None,
|
||||
},
|
||||
(None, Some(tool_calls)) => ChatCompletionDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: Some(DeltaToolCall {
|
||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: delta,
|
||||
}),
|
||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
@ -650,13 +665,12 @@ impl ChatCompletionChunk {
|
||||
name: None,
|
||||
arguments: tool_calls[0].to_string(),
|
||||
},
|
||||
}),
|
||||
},
|
||||
(None, None) => ChatCompletionDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
},
|
||||
}),
|
||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "".to_string(),
|
||||
}),
|
||||
};
|
||||
Self {
|
||||
id: String::new(),
|
||||
@ -982,6 +996,8 @@ impl From<Message> for TextMessage {
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||
pub struct ToolCallMessage {
|
||||
#[schema(example = "assistant")]
|
||||
role: String,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
tool_call_id: String,
|
||||
}
|
||||
@ -1265,6 +1281,7 @@ mod tests {
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
}
|
||||
]
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user