mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00

feat: avoid skip tool test and avoid empty tool prompts

parent 1f72dcf062
commit 20db2c3db8
@@ -757,7 +757,12 @@ class AsyncClient:
                 continue
             payload = byte_payload.decode("utf-8")
             if payload.startswith("data:"):
-                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                payload_data = (
+                    payload.lstrip("data:").rstrip("\n").removeprefix(" ")
+                )
+                if payload_data == "[DONE]":
+                    break
+                json_payload = json.loads(payload_data)
                 try:
                     response = ChatCompletionChunk(**json_payload)
                     yield response
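Note: the new client logic above tolerates the optional space after "data:" and stops cleanly on the OpenAI-style "[DONE]" sentinel instead of failing inside json.loads. A minimal self-contained sketch of that parsing behavior (the sample payloads are illustrative, not from the repo; str.removeprefix requires Python 3.9+):

import json

# Illustrative SSE byte payloads, ending with the "[DONE]" terminator
# that OpenAI-compatible streams send as their final event.
byte_payloads = [
    b'data: {"choices": []}\n',
    b"\n",
    b"data: [DONE]\n",
]

for byte_payload in byte_payloads:
    # Skip keep-alive newlines
    if byte_payload == b"\n":
        continue
    payload = byte_payload.decode("utf-8")
    if payload.startswith("data:"):
        # Strip the SSE prefix, the trailing newline, and one optional space
        payload_data = payload.lstrip("data:").rstrip("\n").removeprefix(" ")
        if payload_data == "[DONE]":
            break  # end of stream, nothing left to decode
        json_payload = json.loads(payload_data)
        print(json_payload)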
@@ -36,6 +36,7 @@ tools = [
                     },
                 },
                 "required": ["location", "format"],
+                "additionalProperties": False,
             },
         },
     },
@@ -62,13 +63,13 @@ tools = [
                     },
                 },
                 "required": ["location", "format", "num_days"],
+                "additionalProperties": False,
             },
         },
     },
 ]


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
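Adding "additionalProperties": False tightens the JSON schema the tool-call grammar is derived from: argument objects carrying unexpected keys no longer validate. A small sketch of the effect using the jsonschema package (the package and schema here are illustrative, not part of this diff):

from jsonschema import ValidationError, validate

schema = {
    "type": "object",
    "properties": {
        "location": {"type": "string"},
        "format": {"type": "string"},
    },
    "required": ["location", "format"],
    "additionalProperties": False,
}

# Exactly the declared keys: validates fine
validate({"location": "Brooklyn, NY", "format": "celsius"}, schema)

# An extra key is now rejected
try:
    validate({"location": "Brooklyn, NY", "format": "celsius", "x": 1}, schema)
except ValidationError as err:
    print(err.message)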
@@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
         max_tokens=100,
         seed=1,
         tools=tools,
-        presence_penalty=-1.1,
+        temperature=0.0,
         messages=[
             {
                 "role": "system",
@@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_auto(
@@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="auto",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
@@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_choice(
@@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
@@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_stream(
@@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
     async for response in responses:
         count += 1

-    assert count == 38
+    assert count == 48
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_insufficient_information(
@@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
 ):
     responses = await flash_llama_grammar_tools.chat(
         max_tokens=100,
-        seed=8,
+        seed=24,
         tools=tools,
         tool_choice="auto",
         messages=[
             {
                 "role": "system",
-                "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
             },
             {
                 "role": "user",
@@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
     )

     assert responses.choices[0].message.content is None
-    assert responses.choices[0].message.tool_calls == [
-        {
-            "function": {
-                "arguments": {
-                    "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-                },
-                "description": None,
-                "name": "notify_error",
-            },
-            "id": 0,
-            "type": "function",
-        }
-    ]
-
+    assert (
+        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
+    )
     assert responses == response_snapshot
@@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
     pub tools: Option<Vec<Tool>>,

     /// A prompt to be appended before the tools
-    #[serde(default = "default_tool_prompt")]
+    #[serde(default)]
     #[schema(
         nullable = true,
         example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
@@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
     pub guideline: Option<String>,
 }

-fn default_tool_prompt() -> Option<String> {
-    Some(
-        "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
-    )
+pub fn default_tool_prompt() -> String {
+    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
 }

 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::ChatTokenizeResponse;
+use crate::{default_tool_prompt, ChatTokenizeResponse};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1158,7 +1158,9 @@ async fn chat_completions(
     let repetition_penalty = presence_penalty.map(|x| x + 2.0);
     let max_new_tokens = max_tokens.or(Some(100));
     let logprobs = logprobs.unwrap_or(false);
-    let tool_prompt = tool_prompt.unwrap_or_default();
+    let tool_prompt = tool_prompt
+        .filter(|s| !s.is_empty())
+        .unwrap_or_else(default_tool_prompt);
     let stop = stop.unwrap_or_default();
     // enable greedy only when temperature is 0
     let (do_sample, temperature) = match temperature {
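With this change, a request that sends an explicitly empty tool_prompt ("") falls back to the default prompt instead of prepending nothing before the tools. A hedged Python sketch of the equivalent fallback logic (the real implementation is the Rust filter/unwrap_or_else chain above; resolve_tool_prompt is a hypothetical helper):

DEFAULT_TOOL_PROMPT = (
    "\nGiven the functions available, please respond with a JSON for a "
    "function call with its proper arguments that best answers the given "
    "prompt. Respond in the format {name: function name, parameters: "
    "dictionary of argument name and its value}.Do not use variables.\n"
)

def resolve_tool_prompt(tool_prompt):
    # None and "" both fall back to the default, mirroring
    # tool_prompt.filter(|s| !s.is_empty()).unwrap_or_else(default_tool_prompt)
    return tool_prompt if tool_prompt else DEFAULT_TOOL_PROMPT

assert resolve_tool_prompt(None) == DEFAULT_TOOL_PROMPT
assert resolve_tool_prompt("") == DEFAULT_TOOL_PROMPT
assert resolve_tool_prompt("Pick a tool.") == "Pick a tool."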