fix: always send event on error, avoid unwraps, refactor and improve tests

This commit is contained in:
David Holtz 2024-10-09 17:20:59 +00:00
parent 7d2aa27161
commit fa140a2eeb
4 changed files with 180 additions and 131 deletions

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " prompt",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728494305,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
) )
count = 0 count = 0
tool_calls_generated = ""
last_response = None
async for response in responses: async for response in responses:
count += 1 count += 1
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response
assert response.choices[0].delta.content is None
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
)
assert count == 28 assert count == 28
assert response == response_snapshot assert last_response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@ -244,3 +253,44 @@ async def test_flash_llama_grammar_tools_insufficient_information(
) )
assert responses == response_snapshot assert responses == response_snapshot
# Streaming counterpart of the non-stream "insufficient information" test:
# with tool_choice="auto" and a prompt unrelated to weather, the model is
# expected to reply with plain text content chunks instead of a tool call.
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,  # fixed seed so the streamed token sequence is reproducible
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=True,
    )
    count = 0
    content_generated = ""
    last_response = None
    # Accumulate the streamed text and keep the final chunk so it can be
    # compared against the stored response snapshot after the loop.
    async for response in responses:
        count += 1
        # NOTE(review): assumes delta.content is always a str on every chunk
        # (never None) for this prompt — TODO confirm against the server's
        # chunk schema.
        content_generated += response.choices[0].delta.content
        last_response = response
        # No chunk should ever carry a tool call for a non-weather prompt.
        assert response.choices[0].delta.tool_calls is None
    # 11 chunks is the observed count for this seed/model — presumably tied to
    # the snapshot; update together with the snapshot if the model changes.
    assert count == 11
    print(content_generated)
    assert (
        content_generated
        == "There is no weather related function available to answer your prompt"
    )
    assert last_response == response_snapshot

View File

@ -355,6 +355,8 @@ pub enum InferError {
MissingTemplateVariable(String), MissingTemplateVariable(String),
#[error("Tool error: {0}")] #[error("Tool error: {0}")]
ToolError(String), ToolError(String),
#[error("Stream event serialization error")]
StreamSerializationError(String),
} }
impl InferError { impl InferError {
@ -368,6 +370,7 @@ impl InferError {
InferError::TemplateError(_) => "template_error", InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable", InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error", InferError::ToolError(_) => "tool_error",
InferError::StreamSerializationError(_) => "stream_serialization_error",
} }
} }
} }

View File

@ -459,10 +459,11 @@ async fn generate_stream(
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
while let Some(raw_event) = response_stream.next().await { while let Some(raw_event) = response_stream.next().await {
yield Ok(match raw_event { yield Ok(raw_event.map_or_else(Event::from, |token| {
Ok(token) => Event::default().json_data(token).unwrap(), Event::default()
Err(err) => Event::from(err), .json_data(token)
}); .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
}));
} }
}; };
@ -847,10 +848,7 @@ async fn completions(
yield Ok(event); yield Ok(event);
} }
Err(_err) => { Err(err) => yield Ok(Event::from(err)),
let event = Event::default();
yield Ok(event);
}
} }
} }
}; };
@ -1226,18 +1224,29 @@ async fn chat_completions(
// static values that will be returned in all cases // static values that will be returned in all cases
let model_id = info.model_id.clone(); let model_id = info.model_id.clone();
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
let send_function_name = false; // TODO: fix to send function name
// switch on stream // switch on stream
if stream { if stream {
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await; generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
}),
))
}
};
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new(); let mut buffer = Vec::new();
let mut json_buffer = String::new(); let mut json_buffer = String::new();
// let mut content_buffer = String::new();
let mut state = if using_tools { let mut state = if using_tools {
StreamState::Buffering StreamState::Buffering
} else { } else {
@ -1246,20 +1255,13 @@ async fn chat_completions(
} }
}; };
let mut response_as_tool = using_tools; let mut response_as_tool = using_tools;
// Regex to match any function name
let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
match result { if let Ok(stream_token) = result {
Ok(stream_token) => {
let token_text = &stream_token.token.text.clone(); let token_text = &stream_token.token.text.clone();
match state { match state {
StreamState::Buffering => { StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", "")); json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token); buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) { if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string(); let function_name = captures[1].to_string();
if function_name == "notify_error" { if function_name == "notify_error" {
@ -1271,32 +1273,6 @@ async fn chat_completions(
state = StreamState::Content { state = StreamState::Content {
skip_close_quote: false, skip_close_quote: false,
}; };
if send_function_name {
// send a message with the the function name
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete = CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
None,
Some(vec![function_name.clone()]),
current_time,
None,
None,
None,
),
);
let event = event.json_data(chat_complete).unwrap();
yield Ok(event);
}
// send all the buffered messages // send all the buffered messages
for stream_token in &buffer { for stream_token in &buffer {
let event = create_event_from_stream_token( let event = create_event_from_stream_token(
@ -1316,21 +1292,23 @@ async fn chat_completions(
StreamState::BufferTrailing => { StreamState::BufferTrailing => {
let infix_text = "\"error\":\""; let infix_text = "\"error\":\"";
json_buffer.push_str(&token_text.replace(" ", "")); json_buffer.push_str(&token_text.replace(" ", ""));
if !json_buffer.contains(infix_text) { // keep capturing until we find the infix text
continue; match json_buffer.find(infix_text) {
} Some(error_index) => {
let error_index = json_buffer.find(infix_text).unwrap();
json_buffer = json_buffer =
json_buffer[error_index + infix_text.len()..].to_string(); json_buffer[error_index + infix_text.len()..].to_string();
}
if json_buffer.is_empty() { None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default(); let event = Event::default();
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
let chat_complete = let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(), model_id.clone(),
@ -1342,11 +1320,13 @@ async fn chat_completions(
None, None,
None, None,
)); ));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
let event = event.json_data(chat_complete).unwrap(); InferError::StreamSerializationError(e.to_string()).into()
yield Ok(event); }));
} }
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content { state = StreamState::Content {
skip_close_quote: true, skip_close_quote: true,
}; };
@ -1370,11 +1350,6 @@ async fn chat_completions(
} }
} }
} }
Err(_err) => {
yield Ok::<Event, Infallible>(Event::default());
break;
}
}
} }
yield Ok::<Event, Infallible>(Event::default().data("[DONE]")); yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
}; };
@ -2507,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
}; };
( (
@ -2684,7 +2660,7 @@ mod tests {
); );
assert!(result.is_ok()); assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap(); let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true); assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. 
San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
} }