Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 04:52:07 +00:00)
feat: refactor and simplify chat stream more, bump tests and support stream_options
This commit is contained in:
parent c4cb54c23e
commit a5ddc9db52
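The headline change is stream_options support on the OpenAI-compatible chat endpoint: when include_usage is requested, the router now emits one final chunk with an empty choices list carrying the token counts. A minimal client-side sketch of that behavior (the local endpoint, port, and prompt are assumptions for illustration; requires the openai Python package):

    from openai import OpenAI

    # Point the standard OpenAI client at a local TGI instance (port is an assumption).
    client = OpenAI(base_url="http://localhost:3000/v1", api_key="-")

    stream = client.chat.completions.create(
        model="tgi",  # "tgi" (or omitting the field) resolves to the served model
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
        stream_options={"include_usage": True},
    )

    for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")
        elif chunk.usage:
            # Usage-only final chunk: "choices": [] plus prompt/completion totals,
            # matching the snapshot updates below.
            print(f"\n{chunk.usage.prompt_tokens=} {chunk.usage.completion_tokens=}")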
@@ -269,6 +269,8 @@ class ResponseComparator(JSONSnapshotExtension):
         def eq_chat_complete_chunk(
             response: ChatCompletionChunk, other: ChatCompletionChunk
         ) -> bool:
+            if len(response.choices) == 0:
+                return len(other.choices) == 0
             return response.choices[0].delta.content == other.choices[0].delta.content

         def eq_response(response: Response, other: Response) -> bool:
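The new guard makes the snapshot comparator tolerate the usage-only chunks introduced above, which carry no choices at all. A self-contained sketch with stand-in objects (SimpleNamespace is a hypothetical substitute for the real ChatCompletionChunk type):

    from types import SimpleNamespace

    usage_chunk = SimpleNamespace(choices=[])  # usage-only chunk from stream_options
    delta_chunk = SimpleNamespace(
        choices=[SimpleNamespace(delta=SimpleNamespace(content="Hello"))]
    )

    def eq_chat_complete_chunk(response, other) -> bool:
        if len(response.choices) == 0:
            return len(other.choices) == 0
        return response.choices[0].delta.content == other.choices[0].delta.content

    assert eq_chat_complete_chunk(usage_chunk, usage_chunk)
    assert not eq_chat_complete_chunk(usage_chunk, delta_chunk)
    assert eq_chat_complete_chunk(delta_chunk, delta_chunk)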
@@ -12,11 +12,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -32,11 +32,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -52,11 +52,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -72,11 +72,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -92,11 +92,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -112,11 +112,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -132,11 +132,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -152,11 +152,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -172,11 +172,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -192,11 +192,20 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
+    "choices": [],
+    "created": 1740516694,
+    "id": "",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "object": "chat.completion.chunk",
+    "system_fingerprint": "3.1.1-dev0-native",
+    "usage": {
+      "completion_tokens": 10,
+      "prompt_tokens": 40,

@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
+        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1739932427,
+  "created": 1740516945,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion",
   "system_fingerprint": "3.1.1-dev0-native",
   "usage": {
-    "completion_tokens": 79,
-    "prompt_tokens": 103,
-    "total_tokens": 182
+    "completion_tokens": 83,
+    "prompt_tokens": 109,
+    "total_tokens": 192
   }
 }

@@ -91,7 +91,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
@@ -142,7 +142,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:

@@ -497,7 +497,7 @@ async def test_flash_llama_tool_reply_response(
     assert responses.choices[0].message.tool_calls is None
     assert (
         responses.choices[0].message.content
-        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
+        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast."
     )

     assert responses == response_snapshot

@@ -1152,10 +1152,10 @@ fn complete_json(partial: &str) -> (String, bool) {
 }

 // Generic function that parses any partial structure into a Map
-fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
-    let (completed, _) = complete_json(partial);
+fn parse_generic_structure(partial: &str) -> Result<(Map<String, Value>, bool), String> {
+    let (completed, quote_open) = complete_json(partial);
     match serde_json::from_str::<Value>(&completed) {
-        Ok(Value::Object(obj)) => Ok(obj),
+        Ok(Value::Object(obj)) => Ok((obj, quote_open)),
         _ => Err("Failed to parse as object".to_string()),
     }
 }
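parse_generic_structure now also reports whether the buffer ended inside an open string, which the streaming path uses to tell when the "content" value has been closed. A rough, simplified Python rendition of the complete-the-partial-JSON idea (the real helpers live in the Rust router; this sketch is ours):

    import json

    def complete_json(partial: str):
        """Close any open string and brackets so a prefix of a JSON object parses.
        Returns the completed text and whether a string was left open."""
        stack, in_string, escaped = [], False, False
        for ch in partial:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = in_string
            elif ch == '"':
                in_string = not in_string
            elif not in_string and ch in "{[":
                stack.append("}" if ch == "{" else "]")
            elif not in_string and ch in "}]":
                if stack:
                    stack.pop()
        completed = partial + ('"' if in_string else "") + "".join(reversed(stack))
        return completed, in_string

    buf = '{"content": "Hello, wor'
    completed, quote_open = complete_json(buf)
    print(json.loads(completed)["content"], quote_open)  # -> Hello, wor True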
@@ -1335,24 +1335,32 @@ pub(crate) async fn chat_completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     metrics::counter!("tgi_request_count").increment(1);
-    let ChatRequest {
-        model,
-        stream,
-        logprobs,
-        // TODO: add back and maybe consolidate the other PR
-        // stream_options,
-        ..
-    } = chat.clone();
-    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
-    let logprobs = logprobs.unwrap_or_default();
-
-    // extract model id from request if specified
-    let model_id = match model.as_deref() {
-        Some("tgi") | None => info.model_id.clone(),
-        Some(m_id) => m_id.to_string(),
-    };
+    // Extract needed fields
+    let model = chat.model.clone();
+    let logprobs = chat.logprobs.unwrap_or_default();
+    let stream = chat.stream;
+    let stream_options = chat.stream_options.clone();
+
+    // Process request (this consumes chat)
+    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
+
+    // Determine model ID
+    let model_id = model
+        .as_deref()
+        .filter(|&m| m != "tgi")
+        .unwrap_or(&info.model_id)
+        .to_string();
     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+
+    // Helper function to get current timestamp
+    let get_timestamp = || {
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs()
+    };

     if stream {
         let (headers, response_stream) =
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
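The match on the requested model collapses into a filter/unwrap chain, and timestamping moves into a reusable closure. The model-id rule, restated in Python for clarity (the function name is ours, invented for illustration):

    def resolve_model_id(requested, served_model_id):
        # "tgi" (or an absent model field) means "use whatever model the server runs"
        if requested is None or requested == "tgi":
            return served_model_id
        return requested

    assert resolve_model_id(None, "meta-llama/Llama-3.1-8B-Instruct") == "meta-llama/Llama-3.1-8B-Instruct"
    assert resolve_model_id("tgi", "served") == "served"
    assert resolve_model_id("my-org/my-model", "served") == "my-org/my-model"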
@@ -1364,103 +1372,198 @@ pub(crate) async fn chat_completions(
             let mut no_tool_chosen = false;
             let mut first_quote_removed = false;

-            while let Some(result) = response_stream.next().await {
-                match result {
-                    Ok(stream_token) => {
-                        let token_text = stream_token.token.text.clone();
-                        json_buffer.push_str(&token_text);
-                        if !name_found {
-                            // since we know tools is attempting to follow a grammar we can attempt to
-                            // partially parse the json_buffer to see if we can extract the function name
-                            if let Ok(function) = parse_partial_json(&json_buffer) {
-                                let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
-                                name_found = true;
-                                if name == "no_tool" {
-                                    no_tool_chosen = true;
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                } else {
-                                    let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
-                                    yield Ok::<Event, Infallible>(tool_name_event);
-                                    let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
-                                    yield Ok::<Event, Infallible>(tool_open_arguments_event);
-                                    // clear the buffer as we know that the buffer is only the function
-                                    // ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
-                                    // we need to keep the `{` to open the arguments and allow the parser to continue
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                }
-                            }
-                            continue;
-                        }
-                        // Process JSON buffer and handle token text
-                        let last_is_brace = json_buffer.ends_with('}');
-                        let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
-                        let mut token_text = stream_token.token.text.clone();
-                        let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
-
-                        // Handle tool usage cases
-                        if using_tools {
-                            if no_tool_chosen {
-                                // Tool without selection ("content" flow)
-                                if let Ok(function) = parse_generic_structure(edited_buffer) {
-                                    if function.get("content").and_then(|c| c.as_str()).is_some() {
-                                        // Handle quotation marks
-                                        if !first_quote_removed {
-                                            first_quote_removed = true;
-                                            if token_text == "\"" || token_text == " \"" { continue; }
-                                            token_text = token_text.replace("\"", "");
-                                        } else if token_text.ends_with('"') {
-                                            token_text = token_text[..token_text.len() - 1].to_string();
-                                        }
-                                    }
-                                    if is_json_complete { break; }
-                                    yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
-                                    continue;
-                                }
-                            } else {
-                                // Tool with selection
-                                if is_json_complete {
-                                    // Final token with possible brace removal
-                                    if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
-                                    yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                    break;
-                                }
-                                yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                continue;
-                            }
-                        } else {
-                            // Default case: standard chat completion
-                            if let Some(details) = stream_token.details.as_ref() {
-                                // Handle final token and only send text if ended because of length
-                                let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
-                                yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
-                                break;
-                            }
-                            yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
-                        }
-                    }
-                    Err(err) => yield Ok(err.into_openai_event()),
-                }
-            }
+            // Process stream tokens
+            while let Some(Ok(stream_token)) = response_stream.next().await {
+                let token_text = stream_token.token.text.clone();
+                let mut events = Vec::new();
+                let mut should_break = false;
+
+                // Get usage information
+                let usage = stream_token.details.as_ref().map(|d| Usage {
+                    completion_tokens: d.generated_tokens,
+                    prompt_tokens: d.input_length,
+                    total_tokens: d.input_length + d.generated_tokens,
+                });
+
+                json_buffer.push_str(&token_text);
+
+                // Phase 1: Function name discovery
+                if !name_found {
+                    if let Ok(function) = parse_partial_json(&json_buffer) {
+                        name_found = true;
+
+                        let name = function
+                            .get("_name")
+                            .and_then(|n| n.as_str())
+                            .unwrap_or_default();
+                        if name == "no_tool" {
+                            no_tool_chosen = true;
+                        } else {
+                            events.push(create_event(
+                                &token_text,
+                                &model_id,
+                                &system_fingerprint,
+                                Some(name),
+                                false,
+                                None,
+                            ));
+                            events.push(create_event(
+                                "{",
+                                &model_id,
+                                &system_fingerprint,
+                                None,
+                                true,
+                                None,
+                            ));
+                        }
+                        // Reset buffer for arguments
+                        json_buffer.clear();
+                        json_buffer.push('{');
+                    }
+                    for event in events {
+                        yield Ok::<Event, Infallible>(event);
+                    }
+                    continue;
+                }
+
+                // Phase 2: Content processing
+                let is_complete_json = json_buffer.ends_with('}')
+                    && serde_json::from_str::<Value>(&json_buffer[..json_buffer.len() - 1]).is_ok();
+                let mut edited_token = token_text;
+
+                // Handle different flows based on context
+                if using_tools {
+                    if no_tool_chosen && !is_complete_json {
+                        // Content-only flow
+                        if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) {
+                            if let Some(_content) = function.get("content").and_then(|c| c.as_str()) {
+                                let cleaned_token = if !first_quote_removed {
+                                    // trim start until the first quote
+                                    first_quote_removed = true;
+                                    edited_token
+                                        .trim_start()
+                                        .strip_prefix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
+                                } else if !quote_open {
+                                    should_break = true;
+                                    // trim end until the last quote
+                                    edited_token
+                                        .trim_end()
+                                        .strip_suffix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
+                                } else {
+                                    edited_token.to_string()
+                                };
+
+                                if !cleaned_token.is_empty() {
+                                    events.push(create_event(
+                                        &cleaned_token,
+                                        &model_id,
+                                        &system_fingerprint,
+                                        None,
+                                        false,
+                                        None,
+                                    ));
+                                }
+                            }
+                        }
+                    } else {
+                        // Tool with arguments flow
+                        if is_complete_json {
+                            edited_token.truncate(edited_token.len() - 1);
+                            should_break = true;
+                        }
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            true,
+                            None,
+                        ));
+                    }
+                } else {
+                    // Standard chat completion flow
+                    if let Some(details) = stream_token.details.as_ref() {
+                        let finish_reason = details.finish_reason.format(true);
+                        let text = if details.finish_reason == FinishReason::Length {
+                            &edited_token
+                        } else {
+                            ""
+                        };
+                        events.push(create_event(
+                            text,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            Some(finish_reason),
+                        ));
+                        should_break = true;
+                    } else {
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            None,
+                        ));
+                    }
+                }
+
+                // Emit all collected events
+                for event in events {
+                    yield Ok::<Event, Infallible>(event);
+                }
+
+                // Emit usage data when requested
+                if let (Some(usage_data), true) = (
+                    usage,
+                    stream_options.as_ref().is_some_and(|o| o.include_usage)
+                ) {
+                    let current_time = get_timestamp();
+
+                    let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
+                        id: String::new(),
+                        created: current_time,
+                        model: model_id.clone(),
+                        system_fingerprint: system_fingerprint.clone(),
+                        choices: vec![],
+                        usage: Some(usage_data),
+                    });
+
+                    yield Ok(Event::default()
+                        .json_data(chat_complete)
+                        .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
+                }
+                if should_break {
+                    break;
+                }
+            }
+
+            // Handle any errors in the stream
+            if let Some(Err(err)) = response_stream.next().await {
+                yield Ok(err.into_openai_event());
+            }
+
+            // Send final completion signal
             yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
         };

         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        // Non-streaming case
+        // Non-streaming response path
         let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;

-        let current_time = std::time::SystemTime::now()
-            .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_default()
-            .as_secs();
-
         let (tool_calls, output) = if using_tools {
             // Parse generated JSON text
             let gen_text_value: Value =
                 serde_json::from_str(&generation.generated_text).map_err(|e| {
                     InferError::ToolError(format!(
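To make the content-only flow above concrete, here is a small, hypothetical Python rendition of the quote-stripping rules the new loop applies to the streamed "content" value (function name and return shape are ours, invented for illustration):

    def clean_content_token(token, first_quote_removed, quote_open):
        """Drop the quote that opens the JSON string value and the one that
        closes it; everything in between streams through untouched.
        Returns (cleaned, first_quote_removed, should_break)."""
        if not first_quote_removed:
            cleaned = token.lstrip()
            cleaned = cleaned[1:] if cleaned.startswith('"') else cleaned
            return cleaned, True, False
        if not quote_open:  # the JSON string was closed on this token
            cleaned = token.rstrip()
            cleaned = cleaned[:-1] if cleaned.endswith('"') else cleaned
            return cleaned, True, True
        return token, True, False

    # e.g. tokens for {"content": "Hi there"} might arrive as ' "', 'Hi', ' there', '"'
    assert clean_content_token(' "Hi', False, True) == ("Hi", True, False)
    assert clean_content_token(' there', True, True) == (" there", True, False)
    assert clean_content_token('"', True, False) == ("", True, True)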
@@ -1468,6 +1571,8 @@ pub(crate) async fn chat_completions(
                         e, generation.generated_text
                     ))
                 })?;
+
+            // Extract function details
             let function = gen_text_value.get("function").ok_or(InferError::ToolError(
                 "No function found in generated text".to_string(),
             ))?;
@@ -1480,24 +1585,28 @@ pub(crate) async fn chat_completions(
                 ))?
                 .to_string();

+            // Prepare arguments (clone and remove _name)
             let mut arguments = function.clone();
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }

+            // Process based on tool name
             match name.as_str() {
                 "no_tool" => {
-                    // parse the content message
-                    let content_message = arguments
+                    // Extract content for no-tool case
+                    let content = arguments
                         .get("content")
                         .and_then(Value::as_str)
                         .ok_or(InferError::ToolError(
                             "No `content` found in generated text".to_string(),
                         ))?
                         .to_string();
-                    (None, Some(content_message))
+                    (None, Some(content))
                 }
                 _ => {
-                    let tool_calls = vec![ToolCall {
+                    // Create tool call for normal function case
+                    let tool_call = ToolCall {
                         id: "0".to_string(),
                         r#type: "function".to_string(),
                         function: FunctionDefinition {
@@ -1505,26 +1614,27 @@ pub(crate) async fn chat_completions(
                             name,
                             arguments,
                         },
-                    }];
-                    (Some(tool_calls), None)
+                    };
+                    (Some(vec![tool_call]), None)
                 }
             }
         } else {
+            // Standard text output
             (None, Some(generation.generated_text))
         };
-        // build the complete response object with the full text
+
+        // Build complete response with all details
         let response = CompletionType::ChatCompletion(ChatCompletion::new(
             model_id,
             system_fingerprint,
             output,
-            current_time,
+            get_timestamp(),
             generation.details.unwrap(),
             logprobs,
             tool_calls,
             input_length,
         ));

         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(response)).into_response())
     }
 }
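For the non-streaming path, the grammar-constrained generation is a JSON object of the shape {"function": {"_name": ..., ...}}. A toy Python walk-through of the extraction the Rust code performs (the example payload is invented; field names and the "0"/"function" tool-call constants follow the diff):

    import json

    generated = '{"function": {"_name": "get_current_weather", "location": "Paris"}}'
    function = json.loads(generated)["function"]
    name = function.pop("_name")  # remaining keys are the tool arguments

    if name == "no_tool":
        tool_calls, output = None, function["content"]
    else:
        tool_calls, output = [{
            "id": "0",
            "type": "function",
            "function": {"name": name, "arguments": function},
        }], None

    print(tool_calls, output)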