mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Make tool_call a list for streaming case
This commit is contained in:
parent
3495248d87
commit
f709466767
@ -727,7 +727,7 @@ pub(crate) struct ChatCompletionChoice {
|
|||||||
pub struct ToolCallDelta {
|
pub struct ToolCallDelta {
|
||||||
#[schema(example = "assistant")]
|
#[schema(example = "assistant")]
|
||||||
role: String,
|
role: String,
|
||||||
tool_calls: DeltaToolCall,
|
tool_calls: Vec<DeltaToolCall>, // Changed to Vec<DeltaToolCall>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
@ -770,15 +770,15 @@ impl ChatCompletionChunk {
|
|||||||
}),
|
}),
|
||||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
tool_calls: DeltaToolCall {
|
tool_calls: tool_calls.iter().enumerate().map(|(index, tool_call)| DeltaToolCall {
|
||||||
index: 0,
|
index: index as u32,
|
||||||
id: String::new(),
|
id: String::new(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: Function {
|
function: Function {
|
||||||
name: None,
|
name: None,
|
||||||
arguments: tool_calls[0].to_string(),
|
arguments: tool_call.to_string(),
|
||||||
},
|
},
|
||||||
},
|
}).collect(),
|
||||||
}),
|
}),
|
||||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
|
@ -1138,7 +1138,27 @@ fn create_event_from_stream_token(
|
|||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
// replace the content with the tool calls if grammar is present
|
||||||
let (content, tool_calls) = if inner_using_tools {
|
let (content, tool_calls) = if inner_using_tools {
|
||||||
(None, Some(vec![stream_token.token.text.clone()]))
|
// Create a DeltaToolCall object
|
||||||
|
let delta_tool_call = DeltaToolCall {
|
||||||
|
index: 0, // Assuming this is the first tool call
|
||||||
|
id: "0".to_string(), // Generate a unique ID here
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: Some(stream_token.token.text.clone()), // Wrap in Some
|
||||||
|
arguments: stream_token.token.text.clone(), // Assuming the arguments are in the token text
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Serialize the DeltaToolCall into a JSON string
|
||||||
|
let tool_call_string = serde_json::to_string(&delta_tool_call).unwrap_or_else(|e| {
|
||||||
|
println!("Failed to serialize DeltaToolCall: {:?}", e);
|
||||||
|
String::new()
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wrap the serialized tool call in a Vec and Option
|
||||||
|
let tool_calls = Some(vec![tool_call_string]);
|
||||||
|
|
||||||
|
(None, tool_calls)
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
let content = if !stream_token.token.special {
|
||||||
Some(stream_token.token.text.clone())
|
Some(stream_token.token.text.clone())
|
||||||
@ -1438,8 +1458,7 @@ pub(crate) async fn chat_completions(
|
|||||||
e
|
e
|
||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
println!("Arguments: {:?}", arguments);
|
|
||||||
println!("Arguments String: {:?}", arguments_string);
|
|
||||||
let tool_calls = vec![ToolCall {
|
let tool_calls = vec![ToolCall {
|
||||||
id: format!("{:09}", 0),
|
id: format!("{:09}", 0),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
|
Loading…
Reference in New Issue
Block a user