mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: improve header init and error handling
This commit is contained in:
parent
25f5e788ae
commit
a62e30462b
@ -641,9 +641,9 @@ async fn completions(
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut x_compute_type = "unknown".to_string();
|
let mut x_compute_type = None;
|
||||||
let mut x_compute_characters = 0u32;
|
let mut x_compute_characters = 0u32;
|
||||||
let mut x_accel_buffering = "no".to_string();
|
let mut x_accel_buffering = None;
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
let mut response_streams = FuturesOrdered::new();
|
let mut response_streams = FuturesOrdered::new();
|
||||||
@ -705,29 +705,28 @@ async fn completions(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
(index, header_rx, sse_rx)
|
(header_rx, sse_rx)
|
||||||
};
|
};
|
||||||
response_streams.push_back(generate_future);
|
response_streams.push_back(generate_future);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut all_rxs = vec![];
|
let mut all_rxs = vec![];
|
||||||
|
|
||||||
while let Some((index, header_rx, sse_rx)) = response_streams.next().await {
|
while let Some((header_rx, sse_rx)) = response_streams.next().await {
|
||||||
all_rxs.push(sse_rx);
|
all_rxs.push(sse_rx);
|
||||||
|
|
||||||
// get the headers from the first response of each stream
|
// get the headers from the first response of each stream
|
||||||
let headers = header_rx.await.expect("Failed to get headers");
|
let headers = header_rx.await.expect("Failed to get headers");
|
||||||
if index == 0 {
|
if x_compute_type.is_none() {
|
||||||
x_compute_type = headers
|
x_compute_type = headers
|
||||||
.get("x-compute-type")
|
.get("x-compute-type")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("unknown")
|
.map(|v| v.to_string());
|
||||||
.to_string();
|
|
||||||
x_accel_buffering = headers
|
x_accel_buffering = headers
|
||||||
.get("x-accel-buffering")
|
.get("x-accel-buffering")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("no")
|
.map(|v| v.to_string());
|
||||||
.to_string();
|
|
||||||
}
|
}
|
||||||
x_compute_characters += headers
|
x_compute_characters += headers
|
||||||
.get("x-compute-characters")
|
.get("x-compute-characters")
|
||||||
@ -737,9 +736,13 @@ async fn completions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
if let Some(x_compute_type) = x_compute_type {
|
||||||
|
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||||
|
}
|
||||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
if let Some(x_accel_buffering) = x_accel_buffering {
|
||||||
|
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
// now sink the sse streams into a single stream and remove the ones that are done
|
// now sink the sse streams into a single stream and remove the ones that are done
|
||||||
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
|
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
|
||||||
@ -806,13 +809,20 @@ async fn completions(
|
|||||||
let choices = generate_responses
|
let choices = generate_responses
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(index, headers, Json(generation))| {
|
.map(|(index, headers, Json(generation))| {
|
||||||
let details = generation.details.unwrap_or_default();
|
let details = generation.details.ok_or((
|
||||||
if index == 0 {
|
// this should never happen but handle if details are missing unexpectedly
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "No details in generation".to_string(),
|
||||||
|
error_type: "no details".to_string(),
|
||||||
|
}),
|
||||||
|
))?;
|
||||||
|
|
||||||
|
if x_compute_type.is_none() {
|
||||||
x_compute_type = headers
|
x_compute_type = headers
|
||||||
.get("x-compute-type")
|
.get("x-compute-type")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("unknown")
|
.map(|v| v.to_string());
|
||||||
.to_string();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// accumulate headers and usage from each response
|
// accumulate headers and usage from each response
|
||||||
@ -857,14 +867,15 @@ async fn completions(
|
|||||||
completion_tokens += details.generated_tokens;
|
completion_tokens += details.generated_tokens;
|
||||||
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
||||||
|
|
||||||
CompletionComplete {
|
Ok(CompletionComplete {
|
||||||
finish_reason: details.finish_reason.to_string(),
|
finish_reason: details.finish_reason.to_string(),
|
||||||
index: index as u32,
|
index: index as u32,
|
||||||
logprobs: None,
|
logprobs: None,
|
||||||
text: generation.generated_text,
|
text: generation.generated_text,
|
||||||
}
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.map_err(|(status, Json(err))| (status, Json(err)))?;
|
||||||
|
|
||||||
let response = Completion {
|
let response = Completion {
|
||||||
id: "".to_string(),
|
id: "".to_string(),
|
||||||
@ -886,7 +897,9 @@ async fn completions(
|
|||||||
|
|
||||||
// headers similar to `generate` but aggregated
|
// headers similar to `generate` but aggregated
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
if let Some(x_compute_type) = x_compute_type {
|
||||||
|
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||||
|
}
|
||||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||||
headers.insert("x-total-time", x_total_time.into());
|
headers.insert("x-total-time", x_total_time.into());
|
||||||
headers.insert("x-validation-time", x_validation_time.into());
|
headers.insert("x-validation-time", x_validation_time.into());
|
||||||
@ -895,8 +908,9 @@ async fn completions(
|
|||||||
headers.insert("x-time-per-token", x_time_per_token.into());
|
headers.insert("x-time-per-token", x_time_per_token.into());
|
||||||
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
|
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
|
||||||
headers.insert("x-generated-tokens", x_generated_tokens.into());
|
headers.insert("x-generated-tokens", x_generated_tokens.into());
|
||||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
if let Some(x_accel_buffering) = x_accel_buffering {
|
||||||
|
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||||
|
}
|
||||||
Ok((headers, Json(response)).into_response())
|
Ok((headers, Json(response)).into_response())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user