mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Make tool_call a list for streaming case
This commit is contained in:
parent
3495248d87
commit
f709466767
@ -727,7 +727,7 @@ pub(crate) struct ChatCompletionChoice {
|
||||
pub struct ToolCallDelta {
|
||||
#[schema(example = "assistant")]
|
||||
role: String,
|
||||
tool_calls: DeltaToolCall,
|
||||
tool_calls: Vec<DeltaToolCall>, // Changed to Vec<DeltaToolCall>
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
@ -770,15 +770,15 @@ impl ChatCompletionChunk {
|
||||
}),
|
||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: DeltaToolCall {
|
||||
index: 0,
|
||||
tool_calls: tool_calls.iter().enumerate().map(|(index, tool_call)| DeltaToolCall {
|
||||
index: index as u32,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: None,
|
||||
arguments: tool_calls[0].to_string(),
|
||||
arguments: tool_call.to_string(),
|
||||
},
|
||||
},
|
||||
}).collect(),
|
||||
}),
|
||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
|
@ -1138,7 +1138,27 @@ fn create_event_from_stream_token(
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
// Create a DeltaToolCall object
|
||||
let delta_tool_call = DeltaToolCall {
|
||||
index: 0, // Assuming this is the first tool call
|
||||
id: "0".to_string(), // Generate a unique ID here
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: Some(stream_token.token.text.clone()), // Wrap in Some
|
||||
arguments: stream_token.token.text.clone(), // Assuming the arguments are in the token text
|
||||
},
|
||||
};
|
||||
|
||||
// Serialize the DeltaToolCall into a JSON string
|
||||
let tool_call_string = serde_json::to_string(&delta_tool_call).unwrap_or_else(|e| {
|
||||
println!("Failed to serialize DeltaToolCall: {:?}", e);
|
||||
String::new()
|
||||
});
|
||||
|
||||
// Wrap the serialized tool call in a Vec and Option
|
||||
let tool_calls = Some(vec![tool_call_string]);
|
||||
|
||||
(None, tool_calls)
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
@ -1438,8 +1458,7 @@ pub(crate) async fn chat_completions(
|
||||
e
|
||||
))
|
||||
})?;
|
||||
println!("Arguments: {:?}", arguments);
|
||||
println!("Arguments String: {:?}", arguments_string);
|
||||
|
||||
let tool_calls = vec![ToolCall {
|
||||
id: format!("{:09}", 0),
|
||||
r#type: "function".to_string(),
|
||||
|
Loading…
Reference in New Issue
Block a user