mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: avoid skip tool test and avoid empty tool prompts
This commit is contained in:
parent
1f72dcf062
commit
20db2c3db8
@ -757,7 +757,12 @@ class AsyncClient:
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
payload_data = (
|
||||
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
|
||||
)
|
||||
if payload_data == "[DONE]":
|
||||
break
|
||||
json_payload = json.loads(payload_data)
|
||||
try:
|
||||
response = ChatCompletionChunk(**json_payload)
|
||||
yield response
|
||||
|
@ -36,6 +36,7 @@ tools = [
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -62,13 +63,13 @@ tools = [
|
||||
},
|
||||
},
|
||||
"required": ["location", "format", "num_days"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||
@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
presence_penalty=-1.1,
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||
assert response.choices[0].message.content is None
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"id": 0,
|
||||
"id": "0",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_auto(
|
||||
@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
tool_choice="auto",
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
|
||||
assert response.choices[0].message.content is None
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"id": 0,
|
||||
"id": "0",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_choice(
|
||||
@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
tool_choice="get_current_weather",
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
|
||||
assert response.choices[0].message.content is None
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"id": 0,
|
||||
"id": "0",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_stream(
|
||||
@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
tool_choice="get_current_weather",
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
|
||||
async for response in responses:
|
||||
count += 1
|
||||
|
||||
assert count == 38
|
||||
assert count == 48
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_insufficient_information(
|
||||
@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
|
||||
):
|
||||
responses = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=8,
|
||||
seed=24,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
|
||||
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
|
||||
)
|
||||
|
||||
assert responses.choices[0].message.content is None
|
||||
assert responses.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"arguments": {
|
||||
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
|
||||
},
|
||||
"description": None,
|
||||
"name": "notify_error",
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
|
||||
assert (
|
||||
responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
|
||||
)
|
||||
assert responses == response_snapshot
|
||||
|
@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
|
||||
/// A prompt to be appended before the tools
|
||||
#[serde(default = "default_tool_prompt")]
|
||||
#[serde(default)]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
||||
@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
|
||||
pub guideline: Option<String>,
|
||||
}
|
||||
|
||||
fn default_tool_prompt() -> Option<String> {
|
||||
Some(
|
||||
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
|
||||
)
|
||||
pub fn default_tool_prompt() -> String {
|
||||
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||
|
@ -8,7 +8,7 @@ use crate::kserve::{
|
||||
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::ChatTokenizeResponse;
|
||||
use crate::{default_tool_prompt, ChatTokenizeResponse};
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
@ -1158,7 +1158,9 @@ async fn chat_completions(
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let logprobs = logprobs.unwrap_or(false);
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
|
Loading…
Reference in New Issue
Block a user