fix: always send event on error, avoid unwraps, refactor and improve tests

David Holtz 2024-10-09 17:20:59 +00:00
parent 7d2aa27161
commit fa140a2eeb
4 changed files with 180 additions and 131 deletions


@@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " prompt",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728494305,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
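
For context, this new snapshot is a single OpenAI-style chat.completion.chunk frame. Below is a minimal sketch of parsing that shape with serde; the struct names are illustrative, not the router's actual types, and the loosely-typed fields use serde_json::Value:

use serde::Deserialize;

// Illustrative shapes only; the router defines its own chunk types.
#[derive(Deserialize, Debug)]
struct ChunkDelta {
    content: Option<String>,
    role: Option<String>,
    tool_calls: Option<serde_json::Value>,
}

#[derive(Deserialize, Debug)]
struct ChunkChoice {
    delta: ChunkDelta,
    finish_reason: Option<String>,
    index: u32,
    logprobs: Option<serde_json::Value>,
}

#[derive(Deserialize, Debug)]
struct Chunk {
    choices: Vec<ChunkChoice>,
    created: u64,
    id: String,
    model: String,
    object: String,
    system_fingerprint: String,
    usage: Option<serde_json::Value>,
}

fn main() {
    let raw = r#"{"choices":[{"delta":{"content":" prompt","role":"assistant",
        "tool_calls":null},"finish_reason":null,"index":0,"logprobs":null}],
        "created":1728494305,"id":"","model":"meta-llama/Llama-3.1-8B-Instruct",
        "object":"chat.completion.chunk","system_fingerprint":"2.3.2-dev0-native",
        "usage":null}"#;
    let chunk: Chunk = serde_json::from_str(raw).expect("snapshot should parse");
    assert_eq!(chunk.choices[0].delta.content.as_deref(), Some(" prompt"));
}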


@@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
)
count = 0
tool_calls_generated = ""
last_response = None
async for response in responses:
count += 1
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response
assert response.choices[0].delta.content is None
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
)
assert count == 28
assert response == response_snapshot
assert last_response == response_snapshot
@pytest.mark.asyncio
@@ -244,3 +253,44 @@ async def test_flash_llama_grammar_tools_insufficient_information(
)
assert responses == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 11
print(content_generated)
assert (
content_generated
== "There is no weather related function available to answer your prompt"
)
assert last_response == response_snapshot


@@ -355,6 +355,8 @@ pub enum InferError {
MissingTemplateVariable(String),
#[error("Tool error: {0}")]
ToolError(String),
#[error("Stream event serialization error")]
StreamSerializationError(String),
}
impl InferError {
@@ -368,6 +370,7 @@ impl InferError {
InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error",
InferError::StreamSerializationError(_) => "stream_serialization_error",
}
}
}
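
The new variant carries the serde error message as its payload, while the #[error(...)] display string stays static. A self-contained sketch of the pattern, assuming the thiserror crate that these attributes suggest and abbreviating the enum to two variants:

use thiserror::Error;

// Abbreviated InferError; only two variants shown.
#[derive(Debug, Error)]
pub enum InferError {
    #[error("Tool error: {0}")]
    ToolError(String),
    // New variant: raised when an SSE payload fails to serialize.
    // Note the Display message does not include the payload, matching the diff.
    #[error("Stream event serialization error")]
    StreamSerializationError(String),
}

impl InferError {
    // Stable machine-readable tag, e.g. for the error_type field of responses.
    pub fn error_type(&self) -> &str {
        match self {
            InferError::ToolError(_) => "tool_error",
            InferError::StreamSerializationError(_) => "stream_serialization_error",
        }
    }
}

fn main() {
    let err = InferError::StreamSerializationError("key must be a string".to_string());
    println!("{} ({})", err, err.error_type());
}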


@@ -459,10 +459,11 @@ async fn generate_stream(
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
while let Some(raw_event) = response_stream.next().await {
yield Ok(match raw_event {
Ok(token) => Event::default().json_data(token).unwrap(),
Err(err) => Event::from(err),
});
yield Ok(raw_event.map_or_else(Event::from, |token| {
Event::default()
.json_data(token)
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
}));
}
};
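
Result::map_or_else takes the error closure first and the success closure second, so the rewrite keeps exactly one yield per event and drops the unwrap(). A runnable sketch of the same shape, with axum's Event and the router's InferError mocked out for illustration:

use serde::Serialize;

// Stand-ins for axum's sse::Event and the router's InferError, for illustration only.
#[derive(Debug)]
struct Event(String);

impl Event {
    fn default() -> Self {
        Event(String::new())
    }
    // Mirrors axum's Event::json_data, which returns Result because serialization can fail.
    fn json_data<T: Serialize>(self, data: T) -> Result<Event, serde_json::Error> {
        serde_json::to_string(&data).map(Event)
    }
}

#[derive(Debug)]
enum InferError {
    StreamSerializationError(String),
}

impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event(format!("{:?}", err))
    }
}

#[derive(Serialize)]
struct Token {
    text: String,
}

fn main() {
    let raw_event: Result<Token, InferError> = Ok(Token { text: "hello".into() });
    // Error handler first, success handler second; both arms produce an Event,
    // so the stream always yields something instead of panicking on unwrap().
    let event = raw_event.map_or_else(Event::from, |token| {
        Event::default()
            .json_data(token)
            .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
    });
    println!("{:?}", event);
}

The completions handler below gets the same treatment: its Err arm now yields Ok(Event::from(err)) rather than a bare Event::default(), so clients receive a structured error frame instead of an empty one.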
@@ -847,10 +848,7 @@ async fn completions(
yield Ok(event);
}
Err(_err) => {
let event = Event::default();
yield Ok(event);
}
Err(err) => yield Ok(Event::from(err)),
}
}
};
@@ -1226,18 +1224,29 @@ async fn chat_completions(
// static values that will be returned in all cases
let model_id = info.model_id.clone();
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
let send_function_name = false; // TODO: fix to send function name
// switch on stream
if stream {
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
}),
))
}
};
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new();
// let mut content_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
} else {
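
Hoisting the Regex::new call out of the stream means a bad pattern surfaces as an ordinary 500 response before any bytes are sent, instead of panicking inside the response body. A rough sketch of that shape, with the handler's status code and error types simplified:

use regex::Regex;

// Simplified stand-ins for the handler's error plumbing.
struct ErrorResponse {
    error: String,
    error_type: String,
}
const INTERNAL_SERVER_ERROR: u16 = 500;

fn chat_completions() -> Result<(), (u16, ErrorResponse)> {
    // Compile once, before the stream starts: a bad pattern becomes a clean
    // error response here instead of a panic later.
    let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
        Ok(regex) => regex,
        Err(e) => {
            return Err((
                INTERNAL_SERVER_ERROR,
                ErrorResponse {
                    error: format!("Failed to compile regex: {}", e),
                    error_type: "regex".to_string(),
                },
            ))
        }
    };
    // ... the stream loop borrows function_regex from here on ...
    let caps = function_regex.captures(r#"{"function":{"_name":"notify_error""#);
    assert_eq!(caps.map(|c| c[1].to_string()).as_deref(), Some("notify_error"));
    Ok(())
}

fn main() {
    chat_completions().unwrap_or_else(|(code, e)| panic!("{}: {}", code, e.error));
}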
@@ -1246,20 +1255,13 @@
}
};
let mut response_as_tool = using_tools;
// Regex to match any function name
let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();
while let Some(result) = response_stream.next().await {
match result {
Ok(stream_token) => {
if let Ok(stream_token) = result {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "notify_error" {
@@ -1271,32 +1273,6 @@
state = StreamState::Content {
skip_close_quote: false,
};
if send_function_name {
// send a message with the function name
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete = CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
None,
Some(vec![function_name.clone()]),
current_time,
None,
None,
None,
),
);
let event = event.json_data(chat_complete).unwrap();
yield Ok(event);
}
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
@@ -1316,21 +1292,23 @@
StreamState::BufferTrailing => {
let infix_text = "\"error\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
if !json_buffer.contains(infix_text) {
continue;
}
let error_index = json_buffer.find(infix_text).unwrap();
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(error_index) => {
json_buffer =
json_buffer[error_index + infix_text.len()..].to_string();
if json_buffer.is_empty() {
}
None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
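
The old path scanned the buffer twice, once with contains() and once with find().unwrap(), and would panic if the two ever disagreed; matching on find() does one scan and makes the still-buffering case explicit. A minimal sketch of the same buffering idea, with hypothetical inputs:

// Strip everything up to and including the infix once it appears in the buffer;
// None means "not seen yet, keep buffering". Mirrors the match in the diff.
fn strip_after_infix(json_buffer: &str, infix_text: &str) -> Option<String> {
    match json_buffer.find(infix_text) {
        Some(error_index) => {
            Some(json_buffer[error_index + infix_text.len()..].to_string())
        }
        None => None, // keep capturing until we find the infix text
    }
}

fn main() {
    let infix = r#""error":""#;
    // Not enough of the payload buffered yet: keep waiting.
    assert_eq!(strip_after_infix(r#"{"function":{"_name":"notify_error","#, infix), None);
    // Once the infix arrives, only the trailing text is left to forward.
    assert_eq!(
        strip_after_infix(r#"{"_name":"notify_error","error":"no weather tool"#, infix),
        Some("no weather tool".to_string())
    );
}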
@@ -1342,11 +1320,13 @@
None,
None,
));
let event = event.json_data(chat_complete).unwrap();
yield Ok(event);
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
@@ -1370,11 +1350,6 @@
}
}
}
Err(_err) => {
yield Ok::<Event, Infallible>(Event::default());
break;
}
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
};
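
As in the OpenAI streaming API, the stream always terminates with a literal [DONE] data frame. A minimal sketch of an SSE route that ends this way, assuming axum 0.7 with the tokio, futures, and async-stream crates (the route and payloads are illustrative):

use axum::response::sse::{Event, Sse};
use futures::Stream;
use std::convert::Infallible;

// Illustrative handler: emit a couple of data frames, then the [DONE] terminator.
async fn demo_stream() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = async_stream::stream! {
        for token in ["Hello", " world"] {
            yield Ok::<Event, Infallible>(Event::default().data(token));
        }
        // OpenAI-compatible clients treat this frame as end-of-stream.
        yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
    };
    Sse::new(stream)
}

#[tokio::main]
async fn main() {
    let app = axum::Router::new().route("/v1/demo", axum::routing::get(demo_stream));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .expect("failed to bind");
    axum::serve(listener, app).await.expect("server error");
}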
@@ -2507,6 +2482,7 @@ impl From&lt;InferError&gt; for (StatusCode, Json&lt;ErrorResponse&gt;) {
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(
@@ -2684,7 +2660,7 @@ mod tests {
);
assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap();
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
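
Finally, swapping unwrap() for expect() in the test changes nothing on success but names the failing step in the panic message. A tiny illustration with a hypothetical error value:

fn main() {
    // Hypothetical failure; the real test destructures the Ok value of the
    // chat-input preparation above.
    let result: Result<(String, Option<String>, bool), String> =
        Err("no chat template".to_string());
    // Panics with: Failed to prepare chat input: "no chat template"
    let (_inputs, _grammar, _using_tools) = result.expect("Failed to prepare chat input");
}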