From fa140a2eeb9585ac55a28b22fb93df05383676e5 Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Wed, 9 Oct 2024 17:20:59 +0000
Subject: [PATCH] fix: always send event on error, avoid unwraps, refactor and improve tests

---
 ...tools_insufficient_information_stream.json |  20 ++
 integration-tests/models/test_tools_llama.py  |  52 +++-
 router/src/infer/mod.rs                       |   3 +
 router/src/server.rs                          | 236 ++++++++----------
 4 files changed, 180 insertions(+), 131 deletions(-)
 create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json

diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json
new file mode 100644
index 00000000..e60b8e80
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": " prompt",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": null,
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1728494305,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.3.2-dev0-native",
+  "usage": null
+}
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index 42a4ddf7..219967d8 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
     )
     count = 0
+    tool_calls_generated = ""
+    last_response = None
     async for response in responses:
         count += 1
+        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        last_response = response
+        assert response.choices[0].delta.content is None
+
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
+    )
     assert count == 28
-    assert response == response_snapshot
+    assert last_response == response_snapshot
 
 
 @pytest.mark.asyncio
@@ -244,3 +253,44 @@ async def test_flash_llama_grammar_tools_insufficient_information(
     )
 
     assert responses == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_insufficient_information_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+            },
+            {
+                "role": "user",
+                "content": "Tell me a story about 3 sea creatures",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    content_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        content_generated += response.choices[0].delta.content
+        last_response = response
+        assert response.choices[0].delta.tool_calls is None
+
+    assert count == 11
+    print(content_generated)
+    assert (
+        content_generated
+        == "There is no weather related function available to answer your prompt"
+    )
+    assert last_response == response_snapshot
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 1c9d5620..896f4f43 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -355,6 +355,8 @@ pub enum InferError {
     MissingTemplateVariable(String),
     #[error("Tool error: {0}")]
     ToolError(String),
+    #[error("Stream event serialization error")]
+    StreamSerializationError(String),
 }
 
 impl InferError {
@@ -368,6 +370,7 @@ impl InferError {
             InferError::TemplateError(_) => "template_error",
             InferError::MissingTemplateVariable(_) => "missing_template_variable",
             InferError::ToolError(_) => "tool_error",
+            InferError::StreamSerializationError(_) => "stream_serialization_error",
         }
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index f6df85d8..6f14044c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -459,10 +459,11 @@ async fn generate_stream(
     let response_stream = async_stream::stream! {
         let mut response_stream = Box::pin(response_stream);
         while let Some(raw_event) = response_stream.next().await {
-            yield Ok(match raw_event {
-                Ok(token) => Event::default().json_data(token).unwrap(),
-                Err(err) => Event::from(err),
-            });
+            yield Ok(raw_event.map_or_else(Event::from, |token| {
+                Event::default()
+                    .json_data(token)
+                    .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
+            }));
         }
     };
 
@@ -847,10 +848,7 @@ async fn completions(
                                 yield Ok(event);
                             }
-                            Err(_err) => {
-                                let event = Event::default();
-                                yield Ok(event);
-                            }
+                            Err(err) => yield Ok(Event::from(err)),
                         }
                     }
                 };
@@ -1226,18 +1224,29 @@ async fn chat_completions(
     // static values that will be returned in all cases
     let model_id = info.model_id.clone();
     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
-    let send_function_name = false; // TODO: fix to send function name
-
     // switch on stream
     if stream {
         let (headers, response_stream) =
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
+
+        // regex to match any function name
+        let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
+            Ok(regex) => regex,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: format!("Failed to compile regex: {}", e),
+                        error_type: "regex".to_string(),
+                    }),
+                ))
+            }
+        };
+
         let response_stream = async_stream::stream! {
             let mut response_stream = Box::pin(response_stream);
             let mut buffer = Vec::new();
             let mut json_buffer = String::new();
-            // let mut content_buffer = String::new();
             let mut state = if using_tools {
                 StreamState::Buffering
             } else {
@@ -1246,133 +1255,99 @@ async fn chat_completions(
                 }
             };
             let mut response_as_tool = using_tools;
-
-            // Regex to match any function name
-            let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();
-
             while let Some(result) = response_stream.next().await {
-                match result {
-                    Ok(stream_token) => {
-                        let token_text = &stream_token.token.text.clone();
-
-                        match state {
-                            StreamState::Buffering => {
-                                json_buffer.push_str(&token_text.replace(" ", ""));
-                                buffer.push(stream_token);
-
-                                if let Some(captures) = function_regex.captures(&json_buffer) {
-                                    let function_name = captures[1].to_string();
-                                    if function_name == "notify_error" {
-                                        state = StreamState::BufferTrailing;
-                                        response_as_tool = false;
-                                        buffer.clear();
-                                        json_buffer.clear();
-                                    } else {
-                                        state = StreamState::Content {
-                                            skip_close_quote: false,
-                                        };
-
-                                        if send_function_name {
-                                            // send a message with the function name
-                                            let event = Event::default();
-                                            let current_time = std::time::SystemTime::now()
-                                                .duration_since(std::time::UNIX_EPOCH)
-                                                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                                                .as_secs();
-
-                                            let chat_complete = CompletionType::ChatCompletionChunk(
-                                                ChatCompletionChunk::new(
-                                                    model_id.clone(),
-                                                    system_fingerprint.clone(),
-                                                    None,
-                                                    Some(vec![function_name.clone()]),
-                                                    current_time,
-                                                    None,
-                                                    None,
-                                                    None,
-                                                ),
-                                            );
-
-                                            let event = event.json_data(chat_complete).unwrap();
-                                            yield Ok(event);
-                                        }
-
-                                        // send all the buffered messages
-                                        for stream_token in &buffer {
-                                            let event = create_event_from_stream_token(
-                                                stream_token,
-                                                logprobs,
-                                                stream_options.clone(),
-                                                response_as_tool,
-                                                system_fingerprint.clone(),
-                                                model_id.clone(),
-                                            );
-                                            yield Ok::<Event, Infallible>(event);
-                                        }
-                                    }
-                                }
-                            }
-                            // if we skipped sending the buffer we need to avoid sending the following json key and quotes
-                            StreamState::BufferTrailing => {
-                                let infix_text = "\"error\":\"";
-                                json_buffer.push_str(&token_text.replace(" ", ""));
-                                if !json_buffer.contains(infix_text) {
+                if let Ok(stream_token) = result {
+                    let token_text = &stream_token.token.text.clone();
+                    match state {
+                        StreamState::Buffering => {
+                            json_buffer.push_str(&token_text.replace(" ", ""));
+                            buffer.push(stream_token);
+                            if let Some(captures) = function_regex.captures(&json_buffer) {
+                                let function_name = captures[1].to_string();
+                                if function_name == "notify_error" {
+                                    state = StreamState::BufferTrailing;
+                                    response_as_tool = false;
+                                    buffer.clear();
+                                    json_buffer.clear();
+                                } else {
+                                    state = StreamState::Content {
+                                        skip_close_quote: false,
+                                    };
+                                    // send all the buffered messages
+                                    for stream_token in &buffer {
+                                        let event = create_event_from_stream_token(
+                                            stream_token,
+                                            logprobs,
+                                            stream_options.clone(),
+                                            response_as_tool,
+                                            system_fingerprint.clone(),
+                                            model_id.clone(),
+                                        );
+                                        yield Ok::<Event, Infallible>(event);
+                                    }
+                                }
+                            }
+                        }
+                        // if we skipped sending the buffer we need to avoid sending the following json key and quotes
+                        StreamState::BufferTrailing => {
+                            let infix_text = "\"error\":\"";
+                            json_buffer.push_str(&token_text.replace(" ", ""));
+                            // keep capturing until we find the infix text
+                            match json_buffer.find(infix_text) {
+                                Some(error_index) => {
+                                    json_buffer =
+                                        json_buffer[error_index + infix_text.len()..].to_string();
+                                }
+                                None => {
                                     continue;
                                 }
-
-                                let error_index = json_buffer.find(infix_text).unwrap();
-                                json_buffer =
-                                    json_buffer[error_index + infix_text.len()..].to_string();
-
-                                if json_buffer.is_empty() {
-                                    let event = Event::default();
-                                    let current_time = std::time::SystemTime::now()
-                                        .duration_since(std::time::UNIX_EPOCH)
-                                        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                                        .as_secs();
-
-                                    let chat_complete =
-                                        CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
-                                            model_id.clone(),
-                                            system_fingerprint.clone(),
-                                            Some(json_buffer.clone()),
-                                            None,
-                                            current_time,
-                                            None,
-                                            None,
-                                            None,
-                                        ));
-
-                                    let event = event.json_data(chat_complete).unwrap();
-                                    yield Ok(event);
-                                }
-
-                                state = StreamState::Content {
-                                    skip_close_quote: true,
-                                };
                             }
-                            StreamState::Content { skip_close_quote } => {
-                                if skip_close_quote && token_text.contains('"') {
-                                    break;
-                                }
-
-                                // send the content
-                                let event = create_event_from_stream_token(
-                                    &stream_token,
-                                    logprobs,
-                                    stream_options.clone(),
-                                    response_as_tool,
-                                    system_fingerprint.clone(),
-                                    model_id.clone(),
-                                );
-
-                                yield Ok::<Event, Infallible>(event);
-                            }
-                        }
-                    }
-                    Err(_err) => {
-                        yield Ok::<Event, Infallible>(Event::default());
-                        break;
-                    }
+                            // if there is leftover text after removing the infix text, we need to send it
+                            if !json_buffer.is_empty() {
+                                let event = Event::default();
+                                let current_time = std::time::SystemTime::now()
+                                    .duration_since(std::time::UNIX_EPOCH)
+                                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                                    .as_secs();
+                                let chat_complete =
+                                    CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+                                        model_id.clone(),
+                                        system_fingerprint.clone(),
+                                        Some(json_buffer.clone()),
+                                        None,
+                                        current_time,
+                                        None,
+                                        None,
+                                        None,
+                                    ));
+                                yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
+                                    InferError::StreamSerializationError(e.to_string()).into()
+                                }));
+                            }
+                            // cleanup the buffers
+                            buffer.clear();
+                            json_buffer.clear();
+                            state = StreamState::Content {
+                                skip_close_quote: true,
+                            };
+                        }
+                        StreamState::Content { skip_close_quote } => {
+                            if skip_close_quote && token_text.contains('"') {
+                                break;
+                            }
+
+                            // send the content
+                            let event = create_event_from_stream_token(
+                                &stream_token,
+                                logprobs,
+                                stream_options.clone(),
+                                response_as_tool,
+                                system_fingerprint.clone(),
+                                model_id.clone(),
+                            );
+
+                            yield Ok::<Event, Infallible>(event);
+                        }
                     }
                 }
             }
         };
@@ -2507,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
         };
 
         (
@@ -2684,7 +2660,7 @@ mod tests {
         );
         assert!(result.is_ok());
-        let (inputs, _grammar, using_tools) = result.unwrap();
+        let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
         assert_eq!(using_tools, true);
         assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
     }
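
Note: the core pattern in this patch — never `unwrap()` while serializing a stream event, and fall back to an error event instead — can be sketched standalone. The sketch below is a minimal illustration under stated assumptions, not the router's actual code: `Token` and `error_event` are hypothetical stand-ins for TGI's stream token payload and its `Event::from(InferError)` conversion; only axum's `Event::json_data` is the real API.

```rust
use std::convert::Infallible;

use axum::response::sse::Event;
use serde::Serialize;

// Hypothetical stand-in for the router's stream token payload.
#[derive(Serialize)]
struct Token {
    text: String,
}

// Hypothetical stand-in for the router's `Event::from(InferError)`
// conversion: package the error as an SSE event instead of panicking.
fn error_event(msg: String) -> Event {
    Event::default().event("error").data(msg)
}

// Same shape as the patched `generate_stream` loop body: every stream item
// becomes an event, and a failed serialization falls back to an error event
// rather than an `unwrap()` panic that would tear down the whole response.
fn to_sse_event(raw: Result<Token, String>) -> Result<Event, Infallible> {
    Ok(raw.map_or_else(error_event, |token| {
        Event::default()
            .json_data(&token)
            .unwrap_or_else(|err| error_event(err.to_string()))
    }))
}
```

Yielding `Ok` in every case keeps the SSE stream's error type `Infallible`: a failure reaches the client as a visible error event instead of a silently dropped chunk, which is exactly what the `generate_stream`, `completions`, and `chat_completions` hunks above change.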
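
Likewise, the `chat_completions` refactor is easier to follow as a plain state machine. The sketch below is a simplified, hypothetical model — strings in, strings out, no SSE plumbing, and `drive` is an invented driver — of the three `StreamState` phases: buffer tokens until the called function is known, swallow the `notify_error` JSON scaffolding, then stream content until the closing quote.

```rust
use regex::Regex;

// Simplified model of the router's three streaming phases.
#[derive(Clone, Copy)]
enum StreamState {
    Buffering,
    BufferTrailing,
    Content { skip_close_quote: bool },
}

// Feed generated tokens one at a time; collect the text that would be
// forwarded to the client.
fn drive(tokens: &[&str]) -> Vec<String> {
    let function_regex =
        Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).expect("static pattern compiles");
    let infix_text = "\"error\":\"";
    let mut out = Vec::new();
    let mut buffer: Vec<String> = Vec::new();
    let mut json_buffer = String::new();
    let mut state = StreamState::Buffering;

    for token_text in tokens {
        match state {
            StreamState::Buffering => {
                json_buffer.push_str(&token_text.replace(' ', ""));
                buffer.push(token_text.to_string());
                if let Some(captures) = function_regex.captures(&json_buffer) {
                    if &captures[1] == "notify_error" {
                        // The model called the error tool: drop the buffered
                        // JSON and strip the trailing `"error":"` key next.
                        state = StreamState::BufferTrailing;
                        buffer.clear();
                        json_buffer.clear();
                    } else {
                        // A real tool call: flush everything buffered so far.
                        out.append(&mut buffer);
                        state = StreamState::Content { skip_close_quote: false };
                    }
                }
            }
            StreamState::BufferTrailing => {
                json_buffer.push_str(&token_text.replace(' ', ""));
                // Keep buffering until the `"error":"` infix shows up, then
                // forward only the leftover text after it.
                if let Some(i) = json_buffer.find(infix_text) {
                    let leftover = json_buffer[i + infix_text.len()..].to_string();
                    if !leftover.is_empty() {
                        out.push(leftover);
                    }
                    json_buffer.clear();
                    state = StreamState::Content { skip_close_quote: true };
                }
            }
            StreamState::Content { skip_close_quote } => {
                // The closing quote of the error string ends the stream.
                if skip_close_quote && token_text.contains('"') {
                    break;
                }
                out.push(token_text.to_string());
            }
        }
    }
    out
}
```

The `notify_error` branch is what the new `test_flash_llama_grammar_tools_insufficient_information_stream` test exercises: the client sees plain content chunks (`delta.tool_calls is None`) rather than a tool call.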