fix: always send event on error, avoid unwraps, refactor and improve tests

This commit is contained in:
David Holtz 2024-10-09 17:20:59 +00:00
parent 7d2aa27161
commit fa140a2eeb
4 changed files with 180 additions and 131 deletions

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " prompt",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728494305,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
) )
count = 0 count = 0
tool_calls_generated = ""
last_response = None
async for response in responses: async for response in responses:
count += 1 count += 1
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response
assert response.choices[0].delta.content is None
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
)
assert count == 28 assert count == 28
assert response == response_snapshot assert last_response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@ -244,3 +253,44 @@ async def test_flash_llama_grammar_tools_insufficient_information(
) )
assert responses == response_snapshot assert responses == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 11
print(content_generated)
assert (
content_generated
== "There is no weather related function available to answer your prompt"
)
assert last_response == response_snapshot

View File

@ -355,6 +355,8 @@ pub enum InferError {
MissingTemplateVariable(String), MissingTemplateVariable(String),
#[error("Tool error: {0}")] #[error("Tool error: {0}")]
ToolError(String), ToolError(String),
#[error("Stream event serialization error")]
StreamSerializationError(String),
} }
impl InferError { impl InferError {
@ -368,6 +370,7 @@ impl InferError {
InferError::TemplateError(_) => "template_error", InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable", InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error", InferError::ToolError(_) => "tool_error",
InferError::StreamSerializationError(_) => "stream_serialization_error",
} }
} }
} }

View File

@ -459,10 +459,11 @@ async fn generate_stream(
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
while let Some(raw_event) = response_stream.next().await { while let Some(raw_event) = response_stream.next().await {
yield Ok(match raw_event { yield Ok(raw_event.map_or_else(Event::from, |token| {
Ok(token) => Event::default().json_data(token).unwrap(), Event::default()
Err(err) => Event::from(err), .json_data(token)
}); .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
}));
} }
}; };
@ -847,10 +848,7 @@ async fn completions(
yield Ok(event); yield Ok(event);
} }
Err(_err) => { Err(err) => yield Ok(Event::from(err)),
let event = Event::default();
yield Ok(event);
}
} }
} }
}; };
@ -1226,18 +1224,29 @@ async fn chat_completions(
// static values that will be returned in all cases // static values that will be returned in all cases
let model_id = info.model_id.clone(); let model_id = info.model_id.clone();
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
let send_function_name = false; // TODO: fix to send function name
// switch on stream // switch on stream
if stream { if stream {
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await; generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
}),
))
}
};
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new(); let mut buffer = Vec::new();
let mut json_buffer = String::new(); let mut json_buffer = String::new();
// let mut content_buffer = String::new();
let mut state = if using_tools { let mut state = if using_tools {
StreamState::Buffering StreamState::Buffering
} else { } else {
@ -1246,133 +1255,99 @@ async fn chat_completions(
} }
}; };
let mut response_as_tool = using_tools; let mut response_as_tool = using_tools;
// Regex to match any function name
let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
match result { if let Ok(stream_token) = result {
Ok(stream_token) => { let token_text = &stream_token.token.text.clone();
let token_text = &stream_token.token.text.clone(); match state {
StreamState::Buffering => {
match state { json_buffer.push_str(&token_text.replace(" ", ""));
StreamState::Buffering => { buffer.push(stream_token);
json_buffer.push_str(&token_text.replace(" ", "")); if let Some(captures) = function_regex.captures(&json_buffer) {
buffer.push(stream_token); let function_name = captures[1].to_string();
if function_name == "notify_error" {
if let Some(captures) = function_regex.captures(&json_buffer) { state = StreamState::BufferTrailing;
let function_name = captures[1].to_string(); response_as_tool = false;
if function_name == "notify_error" { buffer.clear();
state = StreamState::BufferTrailing; json_buffer.clear();
response_as_tool = false; } else {
buffer.clear(); state = StreamState::Content {
json_buffer.clear(); skip_close_quote: false,
} else { };
state = StreamState::Content { // send all the buffered messages
skip_close_quote: false, for stream_token in &buffer {
}; let event = create_event_from_stream_token(
stream_token,
if send_function_name { logprobs,
// send a message with the the function name stream_options.clone(),
let event = Event::default(); response_as_tool,
let current_time = std::time::SystemTime::now() system_fingerprint.clone(),
.duration_since(std::time::UNIX_EPOCH) model_id.clone(),
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) );
.as_secs(); yield Ok::<Event, Infallible>(event);
let chat_complete = CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
None,
Some(vec![function_name.clone()]),
current_time,
None,
None,
None,
),
);
let event = event.json_data(chat_complete).unwrap();
yield Ok(event);
}
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
} }
} }
} }
// if we skipped sending the buffer we need to avoid sending the following json key and quotes }
StreamState::BufferTrailing => { // if we skipped sending the buffer we need to avoid sending the following json key and quotes
let infix_text = "\"error\":\""; StreamState::BufferTrailing => {
json_buffer.push_str(&token_text.replace(" ", "")); let infix_text = "\"error\":\"";
if !json_buffer.contains(infix_text) { json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(error_index) => {
json_buffer =
json_buffer[error_index + infix_text.len()..].to_string();
}
None => {
continue; continue;
} }
let error_index = json_buffer.find(infix_text).unwrap();
json_buffer =
json_buffer[error_index + infix_text.len()..].to_string();
if json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
));
let event = event.json_data(chat_complete).unwrap();
yield Ok(event);
}
state = StreamState::Content {
skip_close_quote: true,
};
} }
StreamState::Content { skip_close_quote } => { // if there is leftover text after removing the infix text, we need to send it
if skip_close_quote && token_text.contains('"') { if !json_buffer.is_empty() {
break; let event = Event::default();
} let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
// send the content .unwrap_or_else(|_| std::time::Duration::from_secs(0))
let event = create_event_from_stream_token( .as_secs();
&stream_token, let chat_complete =
logprobs, CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
stream_options.clone(), model_id.clone(),
response_as_tool, system_fingerprint.clone(),
system_fingerprint.clone(), Some(json_buffer.clone()),
model_id.clone(), None,
); current_time,
None,
yield Ok::<Event, Infallible>(event); None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
} }
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
} }
}
Err(_err) => {
yield Ok::<Event, Infallible>(Event::default());
break;
} }
} }
} }
@ -2507,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
}; };
( (
@ -2684,7 +2660,7 @@ mod tests {
); );
assert!(result.is_ok()); assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap(); let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true); assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
} }