fix: improve header init and error handling

This commit is contained in:
drbh 2024-04-11 21:18:14 +00:00
parent 25f5e788ae
commit a62e30462b

View File

@ -641,9 +641,9 @@ async fn completions(
}) })
.collect(); .collect();
let mut x_compute_type = "unknown".to_string(); let mut x_compute_type = None;
let mut x_compute_characters = 0u32; let mut x_compute_characters = 0u32;
let mut x_accel_buffering = "no".to_string(); let mut x_accel_buffering = None;
if stream { if stream {
let mut response_streams = FuturesOrdered::new(); let mut response_streams = FuturesOrdered::new();
@ -705,29 +705,28 @@ async fn completions(
} }
}); });
(index, header_rx, sse_rx) (header_rx, sse_rx)
}; };
response_streams.push_back(generate_future); response_streams.push_back(generate_future);
} }
let mut all_rxs = vec![]; let mut all_rxs = vec![];
while let Some((index, header_rx, sse_rx)) = response_streams.next().await { while let Some((header_rx, sse_rx)) = response_streams.next().await {
all_rxs.push(sse_rx); all_rxs.push(sse_rx);
// get the headers from the first response of each stream // get the headers from the first response of each stream
let headers = header_rx.await.expect("Failed to get headers"); let headers = header_rx.await.expect("Failed to get headers");
if index == 0 { if x_compute_type.is_none() {
x_compute_type = headers x_compute_type = headers
.get("x-compute-type") .get("x-compute-type")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("unknown") .map(|v| v.to_string());
.to_string();
x_accel_buffering = headers x_accel_buffering = headers
.get("x-accel-buffering") .get("x-accel-buffering")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("no") .map(|v| v.to_string());
.to_string();
} }
x_compute_characters += headers x_compute_characters += headers
.get("x-compute-characters") .get("x-compute-characters")
@ -737,9 +736,13 @@ async fn completions(
} }
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", x_compute_type.parse().unwrap()); if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into()); headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
// now sink the sse streams into a single stream and remove the ones that are done // now sink the sse streams into a single stream and remove the ones that are done
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! { let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
@ -806,13 +809,20 @@ async fn completions(
let choices = generate_responses let choices = generate_responses
.into_iter() .into_iter()
.map(|(index, headers, Json(generation))| { .map(|(index, headers, Json(generation))| {
let details = generation.details.unwrap_or_default(); let details = generation.details.ok_or((
if index == 0 { // this should never happen but handle if details are missing unexpectedly
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "No details in generation".to_string(),
error_type: "no details".to_string(),
}),
))?;
if x_compute_type.is_none() {
x_compute_type = headers x_compute_type = headers
.get("x-compute-type") .get("x-compute-type")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("unknown") .map(|v| v.to_string());
.to_string();
} }
// accumulate headers and usage from each response // accumulate headers and usage from each response
@ -857,14 +867,15 @@ async fn completions(
completion_tokens += details.generated_tokens; completion_tokens += details.generated_tokens;
total_tokens += details.prefill.len() as u32 + details.generated_tokens; total_tokens += details.prefill.len() as u32 + details.generated_tokens;
CompletionComplete { Ok(CompletionComplete {
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.to_string(),
index: index as u32, index: index as u32,
logprobs: None, logprobs: None,
text: generation.generated_text, text: generation.generated_text,
} })
}) })
.collect::<Vec<_>>(); .collect::<Result<Vec<_>, _>>()
.map_err(|(status, Json(err))| (status, Json(err)))?;
let response = Completion { let response = Completion {
id: "".to_string(), id: "".to_string(),
@ -886,7 +897,9 @@ async fn completions(
// headers similar to `generate` but aggregated // headers similar to `generate` but aggregated
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", x_compute_type.parse().unwrap()); if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into()); headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-total-time", x_total_time.into()); headers.insert("x-total-time", x_total_time.into());
headers.insert("x-validation-time", x_validation_time.into()); headers.insert("x-validation-time", x_validation_time.into());
@ -895,8 +908,9 @@ async fn completions(
headers.insert("x-time-per-token", x_time_per_token.into()); headers.insert("x-time-per-token", x_time_per_token.into());
headers.insert("x-prompt-tokens", x_prompt_tokens.into()); headers.insert("x-prompt-tokens", x_prompt_tokens.into());
headers.insert("x-generated-tokens", x_generated_tokens.into()); headers.insert("x-generated-tokens", x_generated_tokens.into());
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
Ok((headers, Json(response)).into_response()) Ok((headers, Json(response)).into_response())
} }
} }