mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: refactor away prepare_chat_input and improve tool grammar apply control flow
This commit is contained in:
parent
b2db1075e4
commit
daa1c6280a
@ -58,7 +58,7 @@ impl ToolGrammar {
|
||||
}))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
|
||||
ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0),
|
||||
};
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
@ -107,18 +107,22 @@ impl ToolGrammar {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tool_schema = JsonSchemaTool {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
let tool_schema = if tools_to_use.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(JsonSchemaTool {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
})
|
||||
};
|
||||
|
||||
Ok((tools_to_use, Some(tool_schema)))
|
||||
Ok((tools_to_use, tool_schema))
|
||||
}
|
||||
}
|
||||
|
@ -12,8 +12,8 @@ mod sagemaker;
|
||||
pub mod usage_stats;
|
||||
mod vertex;
|
||||
|
||||
use crate::infer::tool_grammar::ToolGrammar;
|
||||
use crate::infer::{Infer, InferError};
|
||||
use crate::server::prepare_chat_input;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::IntoPyDict;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -947,22 +947,46 @@ impl ChatRequest {
|
||||
other => (true, other),
|
||||
};
|
||||
// unwrap or default (use "auto" if tools are present, and "none" if not)
|
||||
let choice = tool_choice.unwrap_or_else(|| {
|
||||
let tool_choice = tool_choice.unwrap_or_else(|| {
|
||||
if tools.is_some() {
|
||||
ChatCompletionToolChoiceOption::Auto
|
||||
} else {
|
||||
ChatCompletionToolChoiceOption::NoTool
|
||||
}
|
||||
});
|
||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||
infer,
|
||||
response_format,
|
||||
tools,
|
||||
choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
if response_format.is_some() && tools.is_some() {
|
||||
return Err(InferError::ToolError(
|
||||
"Grammar and tools are mutually exclusive".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let (inputs, grammar, using_tools) = match response_format {
|
||||
Some(format) => {
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
(inputs, Some(format), false)
|
||||
}
|
||||
None => {
|
||||
if let Some(tools) = tools {
|
||||
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
|
||||
|
||||
let grammar = tool_schema
|
||||
.as_ref()
|
||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||
|
||||
let inputs: String = infer.apply_chat_template(
|
||||
guideline,
|
||||
messages,
|
||||
Some((updated_tools, tool_prompt)),
|
||||
)?;
|
||||
(inputs, grammar, tool_schema.is_some())
|
||||
} else {
|
||||
// if no response_format or tools are set simply apply the chat template to generate inputs
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
(inputs, None, false)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok((
|
||||
GenerateRequest {
|
||||
@ -1239,6 +1263,7 @@ pub(crate) enum OutputMessage {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub(crate) struct GenerateRequest {
|
||||
#[schema(example = "My name is Olivier and I")]
|
||||
pub inputs: String,
|
||||
@ -1719,4 +1744,109 @@ mod tests {
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_into_generate_with_tools_and_template() {
|
||||
use crate::infer::Backend;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::ChatTemplateVersions;
|
||||
use crate::HubTokenizerConfig;
|
||||
use crate::TokenizerConfigToken;
|
||||
use crate::Tool;
|
||||
|
||||
use core::future::Future;
|
||||
use std::pin::Pin;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
|
||||
// Mock Backend to avoid network requests. This is never used since we only test the conversion. It is mocked to satisfy the Backend trait.
|
||||
struct MockBackend;
|
||||
impl Backend for MockBackend {
|
||||
fn schedule(
|
||||
&self,
|
||||
_: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>
|
||||
{
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
fn health<'a, 't>(&'a self, _: bool) -> Pin<Box<dyn Future<Output = bool> + Send + 't>>
|
||||
where
|
||||
'a: 't,
|
||||
Self: 't,
|
||||
{
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
}
|
||||
|
||||
let mut tokenizer_config = HubTokenizerConfig::default();
|
||||
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
|
||||
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
|
||||
tokenizer_config.chat_template = Some(
|
||||
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
|
||||
);
|
||||
|
||||
let infer = Infer::new(
|
||||
MockBackend {}, // never used; just to satisfy Infer::new signature
|
||||
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
|
||||
1,
|
||||
tokenizer_config,
|
||||
HubProcessorConfig::default(),
|
||||
);
|
||||
|
||||
let tools = Some(vec![Tool {
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: Some("Get the current weather".to_string()),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location."
|
||||
}
|
||||
},
|
||||
"required": ["location", "format"]
|
||||
}),
|
||||
},
|
||||
}]);
|
||||
|
||||
let request = ChatRequest {
|
||||
model: None,
|
||||
max_tokens: None,
|
||||
logit_bias: None,
|
||||
logprobs: None,
|
||||
n: None,
|
||||
messages: vec![Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText(
|
||||
"What is the weather like in New York?".to_string(),
|
||||
),
|
||||
}],
|
||||
seed: None,
|
||||
stop: None,
|
||||
stream: false,
|
||||
tools: tools,
|
||||
tool_choice: None,
|
||||
tool_prompt: Some("Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.".to_string()),
|
||||
temperature: None,
|
||||
response_format: None,
|
||||
guideline: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
top_p: None,
|
||||
top_logprobs: None,
|
||||
stream_options: None,
|
||||
};
|
||||
|
||||
let (generate, using_tools) = request.try_into_generate(&infer).unwrap();
|
||||
assert_eq!(using_tools, true);
|
||||
assert_eq!(generate.inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
/// HTTP Server logic
|
||||
use crate::config::Config;
|
||||
use crate::infer::tool_grammar::ToolGrammar;
|
||||
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
||||
#[cfg(feature = "kserve")]
|
||||
use crate::kserve::{
|
||||
@ -2514,160 +2513,3 @@ pub enum WebServerError {
|
||||
#[error("Axum error: {0}")]
|
||||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
||||
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||
|
||||
pub(crate) fn prepare_chat_input(
|
||||
infer: &Infer,
|
||||
response_format: Option<GrammarType>,
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: ChatCompletionToolChoiceOption,
|
||||
tool_prompt: &str,
|
||||
guideline: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<PreparedInput, InferError> {
|
||||
if response_format.is_some() && tools.is_some() {
|
||||
return Err(InferError::ToolError(
|
||||
"Grammar and tools are mutually exclusive".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// when response_format is set, tools are not included when applying the chat template to generate inputs
|
||||
if let Some(format) = response_format {
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
return Ok((inputs, Some(format), false));
|
||||
}
|
||||
|
||||
// when no response_format is set and tools are included, apply the chat template with the tools
|
||||
// to generate inputs
|
||||
if let Some(tools) = tools {
|
||||
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
|
||||
|
||||
let grammar = tool_schema
|
||||
.as_ref()
|
||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||
|
||||
let inputs: String = infer.apply_chat_template(
|
||||
guideline,
|
||||
messages,
|
||||
Some((updated_tools, tool_prompt.into())),
|
||||
)?;
|
||||
return Ok((inputs, grammar, tool_schema.is_some()));
|
||||
}
|
||||
|
||||
// if no response_format or tools are set simply apply the chat template to generate inputs
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
Ok((inputs, None, false))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ChatTemplateVersions;
|
||||
use crate::HubTokenizerConfig;
|
||||
use crate::TokenizerConfigToken;
|
||||
use crate::Tool;
|
||||
|
||||
use crate::tests::get_tokenizer;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_prepare_chat_input() {
|
||||
// Mock Backend to avoid network requests
|
||||
struct MockBackend;
|
||||
|
||||
impl Backend for MockBackend {
|
||||
fn schedule(
|
||||
&self,
|
||||
_request: crate::validation::ValidGenerateRequest,
|
||||
) -> Result<
|
||||
tokio_stream::wrappers::UnboundedReceiverStream<
|
||||
Result<InferStreamResponse, InferError>,
|
||||
>,
|
||||
InferError,
|
||||
> {
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
fn health<'a, 'async_trait>(
|
||||
&'a self,
|
||||
_current_health: bool,
|
||||
) -> core::pin::Pin<
|
||||
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
|
||||
>
|
||||
where
|
||||
'a: 'async_trait,
|
||||
Self: 'async_trait,
|
||||
{
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
}
|
||||
|
||||
let backend = MockBackend {};
|
||||
|
||||
let mut tokenizer_config = HubTokenizerConfig::default();
|
||||
|
||||
// mock tokenizer config values
|
||||
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
|
||||
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
|
||||
tokenizer_config.chat_template = Some(
|
||||
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
|
||||
);
|
||||
|
||||
let tokenizer = get_tokenizer();
|
||||
|
||||
let infer = Infer::new(
|
||||
backend,
|
||||
Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false),
|
||||
1,
|
||||
tokenizer_config,
|
||||
HubProcessorConfig::default(),
|
||||
);
|
||||
let response_format = None;
|
||||
let tools = Some(vec![Tool {
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: Some("Get the current weather".to_string()),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location."
|
||||
}
|
||||
},
|
||||
"required": ["location", "format"]
|
||||
}),
|
||||
},
|
||||
}]);
|
||||
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
|
||||
let guideline = None;
|
||||
let messages = vec![Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText(
|
||||
"What is the weather like in New York?".to_string(),
|
||||
),
|
||||
}];
|
||||
|
||||
let result = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
ChatCompletionToolChoiceOption::Auto,
|
||||
tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
|
||||
assert_eq!(using_tools, true);
|
||||
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user