mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: always send event on error, avoid unwraps, refactor and improve tests
This commit is contained in:
parent
7d2aa27161
commit
fa140a2eeb
@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " prompt",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1728494305,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.3.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
}
|
@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
|
tool_calls_generated = ""
|
||||||
|
last_response = None
|
||||||
async for response in responses:
|
async for response in responses:
|
||||||
count += 1
|
count += 1
|
||||||
|
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
|
||||||
|
last_response = response
|
||||||
|
assert response.choices[0].delta.content is None
|
||||||
|
|
||||||
|
assert (
|
||||||
|
tool_calls_generated
|
||||||
|
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
|
||||||
|
)
|
||||||
assert count == 28
|
assert count == 28
|
||||||
assert response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -244,3 +253,44 @@ async def test_flash_llama_grammar_tools_insufficient_information(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
||||||
|
flash_llama_grammar_tools, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await flash_llama_grammar_tools.chat(
|
||||||
|
max_tokens=100,
|
||||||
|
seed=24,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Tell me a story about 3 sea creatures",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
content_generated = ""
|
||||||
|
last_response = None
|
||||||
|
async for response in responses:
|
||||||
|
count += 1
|
||||||
|
content_generated += response.choices[0].delta.content
|
||||||
|
last_response = response
|
||||||
|
assert response.choices[0].delta.tool_calls is None
|
||||||
|
|
||||||
|
assert count == 11
|
||||||
|
print(content_generated)
|
||||||
|
assert (
|
||||||
|
content_generated
|
||||||
|
== "There is no weather related function available to answer your prompt"
|
||||||
|
)
|
||||||
|
assert last_response == response_snapshot
|
||||||
|
@ -355,6 +355,8 @@ pub enum InferError {
|
|||||||
MissingTemplateVariable(String),
|
MissingTemplateVariable(String),
|
||||||
#[error("Tool error: {0}")]
|
#[error("Tool error: {0}")]
|
||||||
ToolError(String),
|
ToolError(String),
|
||||||
|
#[error("Stream event serialization error")]
|
||||||
|
StreamSerializationError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferError {
|
impl InferError {
|
||||||
@ -368,6 +370,7 @@ impl InferError {
|
|||||||
InferError::TemplateError(_) => "template_error",
|
InferError::TemplateError(_) => "template_error",
|
||||||
InferError::MissingTemplateVariable(_) => "missing_template_variable",
|
InferError::MissingTemplateVariable(_) => "missing_template_variable",
|
||||||
InferError::ToolError(_) => "tool_error",
|
InferError::ToolError(_) => "tool_error",
|
||||||
|
InferError::StreamSerializationError(_) => "stream_serialization_error",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -459,10 +459,11 @@ async fn generate_stream(
|
|||||||
let response_stream = async_stream::stream! {
|
let response_stream = async_stream::stream! {
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
while let Some(raw_event) = response_stream.next().await {
|
while let Some(raw_event) = response_stream.next().await {
|
||||||
yield Ok(match raw_event {
|
yield Ok(raw_event.map_or_else(Event::from, |token| {
|
||||||
Ok(token) => Event::default().json_data(token).unwrap(),
|
Event::default()
|
||||||
Err(err) => Event::from(err),
|
.json_data(token)
|
||||||
});
|
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -847,10 +848,7 @@ async fn completions(
|
|||||||
|
|
||||||
yield Ok(event);
|
yield Ok(event);
|
||||||
}
|
}
|
||||||
Err(_err) => {
|
Err(err) => yield Ok(Event::from(err)),
|
||||||
let event = Event::default();
|
|
||||||
yield Ok(event);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1226,18 +1224,29 @@ async fn chat_completions(
|
|||||||
// static values that will be returned in all cases
|
// static values that will be returned in all cases
|
||||||
let model_id = info.model_id.clone();
|
let model_id = info.model_id.clone();
|
||||||
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
let send_function_name = false; // TODO: fix to send function name
|
|
||||||
|
|
||||||
// switch on stream
|
// switch on stream
|
||||||
if stream {
|
if stream {
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) =
|
||||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||||
|
|
||||||
|
// regex to match any function name
|
||||||
|
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
|
||||||
|
Ok(regex) => regex,
|
||||||
|
Err(e) => {
|
||||||
|
return Err((
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: format!("Failed to compile regex: {}", e),
|
||||||
|
error_type: "regex".to_string(),
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let response_stream = async_stream::stream! {
|
let response_stream = async_stream::stream! {
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
let mut json_buffer = String::new();
|
let mut json_buffer = String::new();
|
||||||
// let mut content_buffer = String::new();
|
|
||||||
let mut state = if using_tools {
|
let mut state = if using_tools {
|
||||||
StreamState::Buffering
|
StreamState::Buffering
|
||||||
} else {
|
} else {
|
||||||
@ -1246,133 +1255,99 @@ async fn chat_completions(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let mut response_as_tool = using_tools;
|
let mut response_as_tool = using_tools;
|
||||||
|
|
||||||
// Regex to match any function name
|
|
||||||
let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();
|
|
||||||
|
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
match result {
|
if let Ok(stream_token) = result {
|
||||||
Ok(stream_token) => {
|
let token_text = &stream_token.token.text.clone();
|
||||||
let token_text = &stream_token.token.text.clone();
|
match state {
|
||||||
|
StreamState::Buffering => {
|
||||||
match state {
|
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||||
StreamState::Buffering => {
|
buffer.push(stream_token);
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||||
buffer.push(stream_token);
|
let function_name = captures[1].to_string();
|
||||||
|
if function_name == "notify_error" {
|
||||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
state = StreamState::BufferTrailing;
|
||||||
let function_name = captures[1].to_string();
|
response_as_tool = false;
|
||||||
if function_name == "notify_error" {
|
buffer.clear();
|
||||||
state = StreamState::BufferTrailing;
|
json_buffer.clear();
|
||||||
response_as_tool = false;
|
} else {
|
||||||
buffer.clear();
|
state = StreamState::Content {
|
||||||
json_buffer.clear();
|
skip_close_quote: false,
|
||||||
} else {
|
};
|
||||||
state = StreamState::Content {
|
// send all the buffered messages
|
||||||
skip_close_quote: false,
|
for stream_token in &buffer {
|
||||||
};
|
let event = create_event_from_stream_token(
|
||||||
|
stream_token,
|
||||||
if send_function_name {
|
logprobs,
|
||||||
// send a message with the the function name
|
stream_options.clone(),
|
||||||
let event = Event::default();
|
response_as_tool,
|
||||||
let current_time = std::time::SystemTime::now()
|
system_fingerprint.clone(),
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
model_id.clone(),
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
);
|
||||||
.as_secs();
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
|
||||||
let chat_complete = CompletionType::ChatCompletionChunk(
|
|
||||||
ChatCompletionChunk::new(
|
|
||||||
model_id.clone(),
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
None,
|
|
||||||
Some(vec![function_name.clone()]),
|
|
||||||
current_time,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
|
|
||||||
let event = event.json_data(chat_complete).unwrap();
|
|
||||||
yield Ok(event);
|
|
||||||
}
|
|
||||||
|
|
||||||
// send all the buffered messages
|
|
||||||
for stream_token in &buffer {
|
|
||||||
let event = create_event_from_stream_token(
|
|
||||||
stream_token,
|
|
||||||
logprobs,
|
|
||||||
stream_options.clone(),
|
|
||||||
response_as_tool,
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
model_id.clone(),
|
|
||||||
);
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
}
|
||||||
StreamState::BufferTrailing => {
|
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
||||||
let infix_text = "\"error\":\"";
|
StreamState::BufferTrailing => {
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
let infix_text = "\"error\":\"";
|
||||||
if !json_buffer.contains(infix_text) {
|
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||||
|
// keep capturing until we find the infix text
|
||||||
|
match json_buffer.find(infix_text) {
|
||||||
|
Some(error_index) => {
|
||||||
|
json_buffer =
|
||||||
|
json_buffer[error_index + infix_text.len()..].to_string();
|
||||||
|
}
|
||||||
|
None => {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let error_index = json_buffer.find(infix_text).unwrap();
|
|
||||||
json_buffer =
|
|
||||||
json_buffer[error_index + infix_text.len()..].to_string();
|
|
||||||
|
|
||||||
if json_buffer.is_empty() {
|
|
||||||
let event = Event::default();
|
|
||||||
let current_time = std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
|
||||||
.as_secs();
|
|
||||||
|
|
||||||
let chat_complete =
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
|
||||||
model_id.clone(),
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
Some(json_buffer.clone()),
|
|
||||||
None,
|
|
||||||
current_time,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
|
|
||||||
let event = event.json_data(chat_complete).unwrap();
|
|
||||||
yield Ok(event);
|
|
||||||
}
|
|
||||||
|
|
||||||
state = StreamState::Content {
|
|
||||||
skip_close_quote: true,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
StreamState::Content { skip_close_quote } => {
|
// if there is leftover text after removing the infix text, we need to send it
|
||||||
if skip_close_quote && token_text.contains('"') {
|
if !json_buffer.is_empty() {
|
||||||
break;
|
let event = Event::default();
|
||||||
}
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
// send the content
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
let event = create_event_from_stream_token(
|
.as_secs();
|
||||||
&stream_token,
|
let chat_complete =
|
||||||
logprobs,
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||||
stream_options.clone(),
|
model_id.clone(),
|
||||||
response_as_tool,
|
system_fingerprint.clone(),
|
||||||
system_fingerprint.clone(),
|
Some(json_buffer.clone()),
|
||||||
model_id.clone(),
|
None,
|
||||||
);
|
current_time,
|
||||||
|
None,
|
||||||
yield Ok::<Event, Infallible>(event);
|
None,
|
||||||
|
None,
|
||||||
|
));
|
||||||
|
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||||
|
InferError::StreamSerializationError(e.to_string()).into()
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
// cleanup the buffers
|
||||||
|
buffer.clear();
|
||||||
|
json_buffer.clear();
|
||||||
|
state = StreamState::Content {
|
||||||
|
skip_close_quote: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
StreamState::Content { skip_close_quote } => {
|
||||||
|
if skip_close_quote && token_text.contains('"') {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// send the content
|
||||||
|
let event = create_event_from_stream_token(
|
||||||
|
&stream_token,
|
||||||
|
logprobs,
|
||||||
|
stream_options.clone(),
|
||||||
|
response_as_tool,
|
||||||
|
system_fingerprint.clone(),
|
||||||
|
model_id.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
yield Ok::<Event, Infallible>(event);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
Err(_err) => {
|
|
||||||
yield Ok::<Event, Infallible>(Event::default());
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2507,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
};
|
};
|
||||||
|
|
||||||
(
|
(
|
||||||
@ -2684,7 +2660,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
let (inputs, _grammar, using_tools) = result.unwrap();
|
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
|
||||||
assert_eq!(using_tools, true);
|
assert_eq!(using_tools, true);
|
||||||
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user