diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 0ffcd162..a8e261ad 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -269,6 +269,8 @@ class ResponseComparator(JSONSnapshotExtension):
         def eq_chat_complete_chunk(
             response: ChatCompletionChunk, other: ChatCompletionChunk
         ) -> bool:
+            if len(response.choices) == 0:
+                return len(other.choices) == 0
             return response.choices[0].delta.content == other.choices[0].delta.content
 
         def eq_response(response: Response, other: Response) -> bool:
diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json
index 8c7be4cb..4ae1714a 100644
--- a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json
+++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json
@@ -12,11 +12,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -32,11 +32,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -52,11 +52,11 @@
         "logprobs": null
      }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -72,11 +72,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -92,11 +92,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -112,11 +112,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -132,11 +132,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -152,11 +152,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -172,11 +172,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -192,11 +192,20 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
+    "usage": null
+  },
+  {
+    "choices": [],
+    "created": 1740516694,
+    "id": "",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "object": "chat.completion.chunk",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": {
       "completion_tokens": 10,
       "prompt_tokens": 40,
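Note: the last hunk above is the interesting one. The stream now ends with an extra chunk whose "choices" array is empty and which is the only chunk carrying "usage", mirroring OpenAI's stream_options include_usage contract; the conftest change guards the snapshot comparator against indexing into that empty array. A sketch of the same comparison transliterated to Rust, with stand-in types modeling only the fields the check touches (assumptions, not the router's real structs):

    struct Delta { content: String }
    struct Choice { delta: Delta }
    struct Chunk { choices: Vec<Choice> }

    fn eq_chat_complete_chunk(a: &Chunk, b: &Chunk) -> bool {
        // A usage-only chunk has no choices; it can only match another such chunk.
        if a.choices.is_empty() {
            return b.choices.is_empty();
        }
        // Guard the other side too so indexing cannot panic.
        !b.choices.is_empty() && a.choices[0].delta.content == b.choices[0].delta.content
    }
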
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json
index 4f10aa3b..bd091a52 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json
@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
+        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1739932427,
+  "created": 1740516945,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion",
   "system_fingerprint": "3.1.1-dev0-native",
   "usage": {
-    "completion_tokens": 79,
-    "prompt_tokens": 103,
-    "total_tokens": 182
+    "completion_tokens": 83,
+    "prompt_tokens": 109,
+    "total_tokens": 192
   }
 }
diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py
index 6c359f1e..320e77a9 100644
--- a/integration-tests/models/test_completion_prompts.py
+++ b/integration-tests/models/test_completion_prompts.py
@@ -91,7 +91,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
@@ -142,7 +142,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index ec5c8c11..eed6587b 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -497,7 +497,7 @@ async def test_flash_llama_tool_reply_response(
     assert responses.choices[0].message.tool_calls is None
     assert (
         responses.choices[0].message.content
-        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
+        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast."
     )
 
     assert responses == response_snapshot
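Note: both test loops now branch on an empty "choices" list instead of falling through, so the usage assertions run only on the new trailing chunk. The invariant they enforce, sketched in Rust with hypothetical stand-in types:

    struct Usage { completion_tokens: u32, prompt_tokens: u32, total_tokens: u32 }
    struct Delta { content: String }
    struct Choice { delta: Delta }
    struct Chunk { choices: Vec<Choice>, usage: Option<Usage> }

    fn consume(chunks: Vec<Chunk>) -> (String, Usage) {
        let (mut text, mut usage) = (String::new(), None);
        for c in chunks {
            if c.choices.is_empty() {
                // Usage must arrive exactly once, on the empty chunk.
                assert!(usage.is_none(), "usage reported more than once");
                usage = c.usage;
            } else {
                text.push_str(&c.choices[0].delta.content);
            }
        }
        (text, usage.expect("stream ended without a usage chunk"))
    }
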
diff --git a/router/src/server.rs b/router/src/server.rs
index b74699db..99a87b7a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1152,10 +1152,10 @@ fn complete_json(partial: &str) -> (String, bool) {
 }
 
 // Generic function that parses any partial structure into a Map
-fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
-    let (completed, _) = complete_json(partial);
+fn parse_generic_structure(partial: &str) -> Result<(Map<String, Value>, bool), String> {
+    let (completed, quote_open) = complete_json(partial);
     match serde_json::from_str::<Value>(&completed) {
-        Ok(Value::Object(obj)) => Ok(obj),
+        Ok(Value::Object(obj)) => Ok((obj, quote_open)),
         _ => Err("Failed to parse as object".to_string()),
     }
 }
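Note: parse_generic_structure now surfaces the quote_open flag that complete_json already computed, so the streaming loop can tell whether the buffer still sits inside an unterminated string. The diff does not show complete_json itself; a minimal sketch of what such a helper plausibly does (an assumption -- the real implementation may handle more edge cases):

    // Close a dangling string and any open braces/brackets so a prefix of a
    // JSON object becomes parseable; report whether a string was left open.
    fn complete_json(partial: &str) -> (String, bool) {
        let (mut in_string, mut escaped) = (false, false);
        let mut stack: Vec<char> = Vec::new();
        for c in partial.chars() {
            if escaped {
                escaped = false;
                continue;
            }
            match c {
                '\\' if in_string => escaped = true,
                '"' => in_string = !in_string,
                '{' | '[' if !in_string => stack.push(c),
                '}' | ']' if !in_string => {
                    stack.pop();
                }
                _ => {}
            }
        }
        let mut completed = partial.to_string();
        if in_string {
            completed.push('"');
        }
        while let Some(open) = stack.pop() {
            completed.push(if open == '{' { '}' } else { ']' });
        }
        (completed, in_string)
    }

For example, complete_json(r#"{"content": "Hello, wo"#) would return the parseable {"content": "Hello, wo"} plus quote_open = true; once the model emits the closing quote, quote_open flips to false, which the stream handler below uses as its end-of-content signal.
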
@@ -1335,24 +1335,32 @@ pub(crate) async fn chat_completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     metrics::counter!("tgi_request_count").increment(1);
-    let ChatRequest {
-        model,
-        stream,
-        logprobs,
-        // TODO: add back and maybe consolidate the other PR
-        // stream_options,
-        ..
-    } = chat.clone();
-    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
-    let logprobs = logprobs.unwrap_or_default();
-    // extract model id from request if specified
-    let model_id = match model.as_deref() {
-        Some("tgi") | None => info.model_id.clone(),
-        Some(m_id) => m_id.to_string(),
-    };
+    // Extract needed fields
+    let model = chat.model.clone();
+    let logprobs = chat.logprobs.unwrap_or_default();
+    let stream = chat.stream;
+    let stream_options = chat.stream_options.clone();
+
+    // Process request (this consumes chat)
+    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
+
+    // Determine model ID
+    let model_id = model
+        .as_deref()
+        .filter(|&m| m != "tgi")
+        .unwrap_or(&info.model_id)
+        .to_string();
     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+
+    // Helper function to get current timestamp
+    let get_timestamp = || {
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs()
+    };
+
     if stream {
         let (headers, response_stream) =
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
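Note: the match on model.as_deref() becomes a filter/unwrap_or chain. A quick equivalence check of the three cases (a sketch; resolve_model_id is a hypothetical extraction, not a function in the PR):

    fn resolve_model_id(requested: Option<&str>, served: &str) -> String {
        requested
            .filter(|&m| m != "tgi")
            .unwrap_or(served)
            .to_string()
    }

    fn main() {
        let served = "meta-llama/Llama-3.1-8B-Instruct";
        assert_eq!(resolve_model_id(None, served), served);        // None -> info.model_id
        assert_eq!(resolve_model_id(Some("tgi"), served), served); // "tgi" -> info.model_id
        assert_eq!(resolve_model_id(Some("m"), served), "m");      // anything else wins
    }
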
@@ -1364,103 +1372,198 @@
             let mut no_tool_chosen = false;
             let mut first_quote_removed = false;
-            while let Some(result) = response_stream.next().await {
-                match result {
-                    Ok(stream_token) => {
-                        let token_text = stream_token.token.text.clone();
-                        json_buffer.push_str(&token_text);
-                        if !name_found {
-                            // since we know tools is attempting to follow a grammar we can attempt to
-                            // partially parse the json_buffer to see if we can extract the function name
-                            if let Ok(function) = parse_partial_json(&json_buffer) {
-                                let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
-                                name_found = true;
-                                if name == "no_tool" {
-                                    no_tool_chosen = true;
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                } else {
-                                    let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
-                                    yield Ok::<Event, Infallible>(tool_name_event);
-                                    let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
-                                    yield Ok::<Event, Infallible>(tool_open_arguments_event);
-                                    // clear the buffer as we know that the buffer is only the function
-                                    // ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
-                                    // we need to keep the `{` to open the arguments and allow the parser to continue
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                }
-                            }
+            // Process stream tokens
+            while let Some(Ok(stream_token)) = response_stream.next().await {
+                let token_text = stream_token.token.text.clone();
+                let mut events = Vec::new();
+                let mut should_break = false;
+
+                // Get usage information
+                let usage = stream_token.details.as_ref().map(|d| Usage {
+                    completion_tokens: d.generated_tokens,
+                    prompt_tokens: d.input_length,
+                    total_tokens: d.input_length + d.generated_tokens,
+                });
+
+                json_buffer.push_str(&token_text);
+
+                // Phase 1: Function name discovery
+                if !name_found {
+                    if let Ok(function) = parse_partial_json(&json_buffer) {
+                        name_found = true;
+
+                        let name = function
+                            .get("_name")
+                            .and_then(|n| n.as_str())
+                            .unwrap_or_default();
+                        if name == "no_tool" {
+                            no_tool_chosen = true;
                         } else {
-                            // Process JSON buffer and handle token text
-                            let last_is_brace = json_buffer.ends_with('}');
-                            let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
-                            let mut token_text = stream_token.token.text.clone();
-                            let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
+                            events.push(create_event(
+                                &token_text,
+                                &model_id,
+                                &system_fingerprint,
+                                Some(name),
+                                false,
+                                None,
+                            ));
+                            events.push(create_event(
+                                "{",
+                                &model_id,
+                                &system_fingerprint,
+                                None,
+                                true,
+                                None,
+                            ));
+                        }
 
-                            // Handle tool usage cases
-                            if using_tools {
-                                if no_tool_chosen {
-                                    // Tool without selection ("content" flow)
-                                    if let Ok(function) = parse_generic_structure(edited_buffer) {
-                                        if function.get("content").and_then(|c| c.as_str()).is_some() {
-                                            // Handle quotation marks
-                                            if !first_quote_removed {
-                                                first_quote_removed = true;
-                                                if token_text == "\"" || token_text == " \"" { continue; }
-                                                token_text = token_text.replace("\"", "");
-                                            } else if token_text.ends_with('"') {
-                                                token_text = token_text[..token_text.len() - 1].to_string();
-                                            }
+                        // Reset buffer for arguments
+                        json_buffer.clear();
+                        json_buffer.push('{');
+                    }
 
-                                            if is_json_complete { break; }
-                                            yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
-                                            continue;
-                                        }
-                                    }
-                                    continue;
+                    for event in events {
+                        yield Ok::<Event, Infallible>(event);
+                    }
+                    continue;
+                }
+
+                // Phase 2: Content processing
+                let is_complete_json = json_buffer.ends_with('}')
+                    && serde_json::from_str::<Value>(&json_buffer[..json_buffer.len() - 1]).is_ok();
+                let mut edited_token = token_text;
+
+                // Handle different flows based on context
+                if using_tools {
+                    if no_tool_chosen && !is_complete_json {
+                        // Content-only flow
+                        if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) {
+                            if let Some(_content) = function.get("content").and_then(|c| c.as_str()) {
+                                let cleaned_token = if !first_quote_removed {
+                                    // trim start until the first quote
+                                    first_quote_removed = true;
+                                    edited_token
+                                        .trim_start()
+                                        .strip_prefix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
+                                } else if !quote_open {
+                                    should_break = true;
+                                    // trim end until the last quote
+                                    edited_token
+                                        .trim_end()
+                                        .strip_suffix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
                                 } else {
-                                    // Tool with selection
-                                    if is_json_complete {
-                                        // Final token with possible brace removal
-                                        if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
-                                        yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                        break;
-                                    }
-                                    yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                    continue;
+                                    edited_token.to_string()
                                 };
+
+                                if !cleaned_token.is_empty() {
+                                    events.push(create_event(
+                                        &cleaned_token,
+                                        &model_id,
+                                        &system_fingerprint,
+                                        None,
+                                        false,
+                                        None,
+                                    ));
+                                }
                             }
                         }
                     } else {
-                        // Default case: standard chat completion
-                        if let Some(details) = stream_token.details.as_ref() {
-                            // Handle final token and only send text if ended because of length
-                            let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
-                            yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
-                            break;
-                        }
-                        yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
                     }
                 }
+                    } else {
+                        // Tool with arguments flow
+                        if is_complete_json {
+                            edited_token.truncate(edited_token.len() - 1);
+                            should_break = true;
+                        }
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            true,
+                            None,
+                        ));
                     }
-                    Err(err) => yield Ok(err.into_openai_event()),
+                } else {
+                    // Standard chat completion flow
+                    if let Some(details) = stream_token.details.as_ref() {
+                        let finish_reason = details.finish_reason.format(true);
+                        let text = if details.finish_reason == FinishReason::Length {
+                            &edited_token
+                        } else {
+                            ""
+                        };
+                        events.push(create_event(
+                            text,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            Some(finish_reason),
+                        ));
+                        should_break = true;
+                    } else {
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            None,
+                        ));
+                    }
+                }
+
+                // Emit all collected events
+                for event in events {
+                    yield Ok::<Event, Infallible>(event);
+                }
+
+                // Emit usage data when requested
+                if let (Some(usage_data), true) = (
+                    usage,
+                    stream_options.as_ref().is_some_and(|o| o.include_usage)
+                ) {
+                    let current_time = get_timestamp();
+
+                    let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
+                        id: String::new(),
+                        created: current_time,
+                        model: model_id.clone(),
+                        system_fingerprint: system_fingerprint.clone(),
+                        choices: vec![],
+                        usage: Some(usage_data),
+                    });
+
+                    yield Ok(Event::default()
+                        .json_data(chat_complete)
+                        .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
+                }
+                if should_break {
+                    break;
+                }
                 }
             }
+
+            // Handle any errors in the stream
+            if let Some(Err(err)) = response_stream.next().await {
+                yield Ok(err.into_openai_event());
+            }
+
+            // Send final completion signal
             yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
         };
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        // Non-streaming case
+        // Non-streaming response path
         let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
-        let current_time = std::time::SystemTime::now()
-            .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_default()
-            .as_secs();
-
         let (tool_calls, output) = if using_tools {
+            // Parse generated JSON text
             let gen_text_value: Value =
                 serde_json::from_str(&generation.generated_text).map_err(|e| {
                     InferError::ToolError(format!(
                         "Failed to parse generated text: {} {}",
                         e,
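Note: when no tool is chosen, the grammar still wraps the reply as {"content": "..."}, so the loop above strips the opening quote of the content string from the first emitted token and the closing quote from the last one (detected via quote_open). The cleanup, replayed on hypothetical token values:

    fn main() {
        // First content token as the grammar might emit it, e.g. ` "Hel`:
        let first = " \"Hel";
        let cleaned = first.trim_start().strip_prefix('"').unwrap_or(first);
        assert_eq!(cleaned, "Hel");

        // Final content token once quote_open is false, e.g. `lo!"`:
        let last = "lo!\"";
        let cleaned = last.trim_end().strip_suffix('"').unwrap_or(last);
        assert_eq!(cleaned, "lo!");
    }
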
@@ -1468,6 +1571,8 @@
                         generation.generated_text
                     ))
                 })?;
+
+            // Extract function details
             let function = gen_text_value.get("function").ok_or(InferError::ToolError(
                 "No function found in generated text".to_string(),
             ))?;
@@ -1480,24 +1585,28 @@
                 ))?
                 .to_string();
 
+            // Prepare arguments (clone and remove _name)
             let mut arguments = function.clone();
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }
+
+            // Process based on tool name
             match name.as_str() {
                 "no_tool" => {
-                    // parse the content message
-                    let content_message = arguments
+                    // Extract content for no-tool case
+                    let content = arguments
                         .get("content")
                         .and_then(Value::as_str)
                         .ok_or(InferError::ToolError(
                             "No `content` found in generated text".to_string(),
                         ))?
                         .to_string();
-                    (None, Some(content_message))
+                    (None, Some(content))
                 }
                 _ => {
-                    let tool_calls = vec![ToolCall {
+                    // Create tool call for normal function case
+                    let tool_call = ToolCall {
                         id: "0".to_string(),
                         r#type: "function".to_string(),
                         function: FunctionDefinition {
@@ -1505,26 +1614,27 @@
                             name,
                             arguments,
                         },
-                    }];
-                    (Some(tool_calls), None)
+                    };
+                    (Some(vec![tool_call]), None)
                 }
             }
         } else {
+            // Standard text output
            (None, Some(generation.generated_text))
         };
-        // build the complete response object with the full text
+
+        // Build complete response with all details
         let response = CompletionType::ChatCompletion(ChatCompletion::new(
             model_id,
             system_fingerprint,
             output,
-            current_time,
+            get_timestamp(),
             generation.details.unwrap(),
             logprobs,
             tool_calls,
             input_length,
         ));
-        // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(response)).into_response())
     }
 }
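Note: on the non-streaming side, tool calls arrive as one JSON object shaped like {"function": {"_name": ..., ...}}; the code pops _name out and ships the remainder as the call's arguments. A worked example of that extraction with a hypothetical weather call (uses the serde_json crate the router already depends on):

    use serde_json::{json, Value};

    fn main() {
        let gen_text_value: Value = json!({
            "function": {"_name": "get_current_weather", "location": "Paris", "format": "celsius"}
        });
        let function = gen_text_value.get("function").expect("no function in generated text");
        let name = function
            .get("_name")
            .and_then(Value::as_str)
            .expect("no _name in generated text")
            .to_string();
        let mut arguments = function.clone();
        if let Value::Object(ref mut props) = arguments {
            props.remove("_name");
        }
        assert_eq!(name, "get_current_weather");
        assert_eq!(arguments, json!({"location": "Paris", "format": "celsius"}));
    }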