mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
feat: serialize function definition with serialize_as_string
This commit is contained in:
parent
7d852cde78
commit
aad1901aa5
@ -1508,26 +1508,7 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"FunctionCall": {
|
"FunctionDefinition": {
|
||||||
"type": "object",
|
|
||||||
"required": [
|
|
||||||
"name",
|
|
||||||
"arguments"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"arguments": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"description": {
|
|
||||||
"type": "string",
|
|
||||||
"nullable": true
|
|
||||||
},
|
|
||||||
"name": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"FunctionDefinitionDeprecated": {
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
"name",
|
"name",
|
||||||
@ -1544,23 +1525,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"FunctionDefinition": {
|
|
||||||
"type": "object",
|
|
||||||
"required": [
|
|
||||||
"name",
|
|
||||||
"parameters"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"parameters": {},
|
|
||||||
"description": {
|
|
||||||
"type": "string",
|
|
||||||
"nullable": true
|
|
||||||
},
|
|
||||||
"name": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"FunctionName": {
|
"FunctionName": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
@ -2263,14 +2227,7 @@
|
|||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"function": {
|
"function": {
|
||||||
"oneOf": [
|
"$ref": "#/components/schemas/FunctionDefinition"
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/FunctionDefinition"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/FunctionDefinitionDeprecated"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@ -2287,7 +2244,7 @@
|
|||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"function": {
|
"function": {
|
||||||
"$ref": "#/components/schemas/FunctionCall"
|
"$ref": "#/components/schemas/FunctionDefinition"
|
||||||
},
|
},
|
||||||
"id": {
|
"id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
|
@ -305,7 +305,7 @@ chat = client.chat_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(chat.choices[0].message.tool_calls)
|
print(chat.choices[0].message.tool_calls)
|
||||||
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')]
|
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')]
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -903,44 +903,7 @@ mod tests {
|
|||||||
let tool_prompt = "This default prompt will be used".to_string();
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
let tools_and_prompt = Some((tools, tool_prompt));
|
let tools_and_prompt = Some((tools, tool_prompt));
|
||||||
let result = ct.apply(msgs, tools_and_prompt);
|
let result = ct.apply(msgs, tools_and_prompt);
|
||||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
|
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":\"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\"}}]\nThis default prompt will be used [/INST]".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_template_with_default_tool_template_arguments_deprecated() {
|
|
||||||
let ct = ChatTemplate::new(
|
|
||||||
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
|
|
||||||
Some(TokenizerConfigToken::String("<s>".to_string())),
|
|
||||||
Some(TokenizerConfigToken::String("</s>".to_string())),
|
|
||||||
);
|
|
||||||
|
|
||||||
// convert TextMessage to Message
|
|
||||||
let msgs: Vec<Message> = vec![
|
|
||||||
Message {
|
|
||||||
name: None,
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: MessageContent::SingleText(
|
|
||||||
"I'd like to show off how chat templating works!".to_string(),
|
|
||||||
),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
name: None,
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
name: None,
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: MessageContent::SingleText("Just testing".to_string()),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","arguments": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
|
||||||
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
|
||||||
let tool_prompt = "This default prompt will be used".to_string();
|
|
||||||
let tools_and_prompt = Some((tools, tool_prompt));
|
|
||||||
let result = ct.apply(msgs, tools_and_prompt);
|
|
||||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
|
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -974,7 +937,7 @@ mod tests {
|
|||||||
let tool_prompt = "This default prompt will be used".to_string();
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
let tools_and_prompt = Some((tools, tool_prompt));
|
let tools_and_prompt = Some((tools, tool_prompt));
|
||||||
let result = ct.apply(msgs, tools_and_prompt);
|
let result = ct.apply(msgs, tools_and_prompt);
|
||||||
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ impl ToolGrammar {
|
|||||||
description: Some(
|
description: Some(
|
||||||
"Open ended response with no specific tool selected".to_string(),
|
"Open ended response with no specific tool selected".to_string(),
|
||||||
),
|
),
|
||||||
parameters: json!({
|
arguments: json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"content": {
|
"content": {
|
||||||
@ -83,7 +83,7 @@ impl ToolGrammar {
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Value::Object(args) = func.parameters {
|
if let Value::Object(args) = func.arguments {
|
||||||
if let Some(Value::Object(props)) = args.get("properties") {
|
if let Some(Value::Object(props)) = args.get("properties") {
|
||||||
properties.extend(props.clone());
|
properties.extend(props.clone());
|
||||||
}
|
}
|
||||||
|
@ -742,11 +742,11 @@ pub(crate) struct DeltaToolCall {
|
|||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
pub function: FunctionCallChunk,
|
pub function: Function,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
pub(crate) struct FunctionCallChunk {
|
pub(crate) struct Function {
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
pub arguments: String,
|
pub arguments: String,
|
||||||
}
|
}
|
||||||
@ -757,7 +757,7 @@ impl ChatCompletionChunk {
|
|||||||
model: String,
|
model: String,
|
||||||
system_fingerprint: String,
|
system_fingerprint: String,
|
||||||
delta: Option<String>,
|
delta: Option<String>,
|
||||||
tool_calls: Option<FunctionCallChunk>,
|
tool_calls: Option<Vec<String>>,
|
||||||
created: u64,
|
created: u64,
|
||||||
logprobs: Option<ChatCompletionLogprobs>,
|
logprobs: Option<ChatCompletionLogprobs>,
|
||||||
finish_reason: Option<String>,
|
finish_reason: Option<String>,
|
||||||
@ -774,7 +774,10 @@ impl ChatCompletionChunk {
|
|||||||
index: 0,
|
index: 0,
|
||||||
id: String::new(),
|
id: String::new(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: tool_calls,
|
function: Function {
|
||||||
|
name: None,
|
||||||
|
arguments: tool_calls[0].to_string(),
|
||||||
|
},
|
||||||
}],
|
}],
|
||||||
}),
|
}),
|
||||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
@ -1130,14 +1133,15 @@ pub(crate) struct FunctionDefinition {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
#[serde(alias = "arguments")]
|
#[serde(alias = "parameters", serialize_with = "serialize_as_string")]
|
||||||
pub parameters: serde_json::Value,
|
pub arguments: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
|
fn serialize_as_string<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
pub(crate) struct FunctionCall {
|
where
|
||||||
pub name: String,
|
S: serde::Serializer,
|
||||||
pub arguments: String,
|
{
|
||||||
|
serializer.serialize_str(&value.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
@ -1163,7 +1167,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
pub(crate) struct ToolCall {
|
pub(crate) struct ToolCall {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
pub function: FunctionCall,
|
pub function: FunctionDefinition,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
@ -1682,19 +1686,19 @@ mod tests {
|
|||||||
tool_calls: vec![ToolCall {
|
tool_calls: vec![ToolCall {
|
||||||
id: "0".to_string(),
|
id: "0".to_string(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: FunctionCall {
|
function: FunctionDefinition {
|
||||||
|
description: None,
|
||||||
name: "myfn".to_string(),
|
name: "myfn".to_string(),
|
||||||
arguments: json!({
|
arguments: json!({
|
||||||
"format": "csv"
|
"format": "csv"
|
||||||
})
|
}),
|
||||||
.to_string(),
|
|
||||||
},
|
},
|
||||||
}],
|
}],
|
||||||
});
|
});
|
||||||
let serialized = serde_json::to_string(&message).unwrap();
|
let serialized = serde_json::to_string(&message).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
serialized,
|
serialized,
|
||||||
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
|
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,9 +12,10 @@ use crate::sagemaker::{
|
|||||||
};
|
};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::vertex::vertex_compatibility;
|
use crate::vertex::vertex_compatibility;
|
||||||
|
use crate::ChatTokenizeResponse;
|
||||||
use crate::{
|
use crate::{
|
||||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
|
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||||
FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
|
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
|
||||||
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
|
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
|
||||||
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
|
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
|
||||||
@ -24,9 +25,8 @@ use crate::{
|
|||||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||||
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
||||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||||
CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||||
};
|
};
|
||||||
use crate::{ChatTokenizeResponse, FunctionCall};
|
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||||
use crate::{ModelInfo, ModelsInfo};
|
use crate::{ModelInfo, ModelsInfo};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
@ -1117,7 +1117,6 @@ pub(crate) async fn completions(
|
|||||||
enum StreamState {
|
enum StreamState {
|
||||||
Buffering,
|
Buffering,
|
||||||
BufferTrailing,
|
BufferTrailing,
|
||||||
Arguments,
|
|
||||||
Content { skip_close_quote: bool },
|
Content { skip_close_quote: bool },
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1127,7 +1126,6 @@ fn create_event_from_stream_token(
|
|||||||
logprobs: bool,
|
logprobs: bool,
|
||||||
stream_options: Option<StreamOptions>,
|
stream_options: Option<StreamOptions>,
|
||||||
inner_using_tools: bool,
|
inner_using_tools: bool,
|
||||||
partial_call: Option<FunctionCallChunk>,
|
|
||||||
system_fingerprint: String,
|
system_fingerprint: String,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
) -> Event {
|
) -> Event {
|
||||||
@ -1143,16 +1141,7 @@ fn create_event_from_stream_token(
|
|||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
// replace the content with the tool calls if grammar is present
|
||||||
let (content, tool_calls) = if inner_using_tools {
|
let (content, tool_calls) = if inner_using_tools {
|
||||||
match partial_call {
|
(None, Some(vec![stream_token.token.text.clone()]))
|
||||||
Some(partial_call) => (None, Some(partial_call)),
|
|
||||||
None => (
|
|
||||||
None,
|
|
||||||
Some(FunctionCallChunk {
|
|
||||||
name: None,
|
|
||||||
arguments: stream_token.token.text.clone(),
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
let content = if !stream_token.token.special {
|
||||||
Some(stream_token.token.text.clone())
|
Some(stream_token.token.text.clone())
|
||||||
@ -1269,7 +1258,7 @@ pub(crate) async fn chat_completions(
|
|||||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||||
|
|
||||||
// regex to match any function name
|
// regex to match any function name
|
||||||
let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
|
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
|
||||||
Ok(regex) => regex,
|
Ok(regex) => regex,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Err((
|
return Err((
|
||||||
@ -1284,6 +1273,7 @@ pub(crate) async fn chat_completions(
|
|||||||
|
|
||||||
let response_stream = async_stream::stream! {
|
let response_stream = async_stream::stream! {
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
|
let mut buffer = Vec::new();
|
||||||
let mut json_buffer = String::new();
|
let mut json_buffer = String::new();
|
||||||
let mut state = if using_tools {
|
let mut state = if using_tools {
|
||||||
StreamState::Buffering
|
StreamState::Buffering
|
||||||
@ -1300,27 +1290,30 @@ pub(crate) async fn chat_completions(
|
|||||||
match state {
|
match state {
|
||||||
StreamState::Buffering => {
|
StreamState::Buffering => {
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||||
if let Some(captures) = function_name_regex.captures(&json_buffer) {
|
buffer.push(stream_token);
|
||||||
|
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||||
let function_name = captures[1].to_string();
|
let function_name = captures[1].to_string();
|
||||||
if function_name == "no_tool" {
|
if function_name == "no_tool" {
|
||||||
state = StreamState::BufferTrailing;
|
state = StreamState::BufferTrailing;
|
||||||
response_as_tool = false;
|
response_as_tool = false;
|
||||||
|
buffer.clear();
|
||||||
json_buffer.clear();
|
json_buffer.clear();
|
||||||
} else {
|
} else {
|
||||||
state = StreamState::Arguments;
|
state = StreamState::Content {
|
||||||
let event = create_event_from_stream_token(
|
skip_close_quote: false,
|
||||||
&stream_token,
|
};
|
||||||
logprobs,
|
// send all the buffered messages
|
||||||
stream_options.clone(),
|
for stream_token in &buffer {
|
||||||
response_as_tool,
|
let event = create_event_from_stream_token(
|
||||||
Some(FunctionCallChunk {
|
stream_token,
|
||||||
name: Some(function_name),
|
logprobs,
|
||||||
arguments: "{".to_string()
|
stream_options.clone(),
|
||||||
}),
|
response_as_tool,
|
||||||
system_fingerprint.clone(),
|
system_fingerprint.clone(),
|
||||||
model_id.clone(),
|
model_id.clone(),
|
||||||
);
|
);
|
||||||
yield Ok::<Event, Infallible>(event);
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1361,32 +1354,12 @@ pub(crate) async fn chat_completions(
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
// cleanup the buffers
|
// cleanup the buffers
|
||||||
|
buffer.clear();
|
||||||
json_buffer.clear();
|
json_buffer.clear();
|
||||||
state = StreamState::Content {
|
state = StreamState::Content {
|
||||||
skip_close_quote: true,
|
skip_close_quote: true,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
StreamState::Arguments => {
|
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
|
||||||
|
|
||||||
// If we are at the end of the json we can stop
|
|
||||||
let function: Result<serde::de::IgnoredAny, _> = serde_json::from_str(&json_buffer);
|
|
||||||
if let Ok(_) = function {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the content
|
|
||||||
let event = create_event_from_stream_token(
|
|
||||||
&stream_token,
|
|
||||||
logprobs,
|
|
||||||
stream_options.clone(),
|
|
||||||
response_as_tool,
|
|
||||||
None,
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
model_id.clone(),
|
|
||||||
);
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
}
|
|
||||||
StreamState::Content { skip_close_quote } => {
|
StreamState::Content { skip_close_quote } => {
|
||||||
if skip_close_quote && token_text.contains('"') {
|
if skip_close_quote && token_text.contains('"') {
|
||||||
break;
|
break;
|
||||||
@ -1398,7 +1371,6 @@ pub(crate) async fn chat_completions(
|
|||||||
logprobs,
|
logprobs,
|
||||||
stream_options.clone(),
|
stream_options.clone(),
|
||||||
response_as_tool,
|
response_as_tool,
|
||||||
None,
|
|
||||||
system_fingerprint.clone(),
|
system_fingerprint.clone(),
|
||||||
model_id.clone(),
|
model_id.clone(),
|
||||||
);
|
);
|
||||||
@ -1466,9 +1438,10 @@ pub(crate) async fn chat_completions(
|
|||||||
let tool_calls = vec![ToolCall {
|
let tool_calls = vec![ToolCall {
|
||||||
id: "0".to_string(),
|
id: "0".to_string(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: FunctionCall {
|
function: FunctionDefinition {
|
||||||
|
description: None,
|
||||||
name,
|
name,
|
||||||
arguments: arguments.to_string(),
|
arguments,
|
||||||
},
|
},
|
||||||
}];
|
}];
|
||||||
(Some(tool_calls), None)
|
(Some(tool_calls), None)
|
||||||
@ -1599,8 +1572,8 @@ StreamOptions,
|
|||||||
DeltaToolCall,
|
DeltaToolCall,
|
||||||
Tool,
|
Tool,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
Function,
|
||||||
FunctionDefinition,
|
FunctionDefinition,
|
||||||
FunctionCall,
|
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
ChatTokenizeResponse,
|
ChatTokenizeResponse,
|
||||||
|
Loading…
Reference in New Issue
Block a user