diff --git a/router/src/server.rs b/router/src/server.rs index 2ed6fc25..d140509e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -641,9 +641,9 @@ async fn completions( }) .collect(); - let mut x_compute_type = "unknown".to_string(); + let mut x_compute_type = None; let mut x_compute_characters = 0u32; - let mut x_accel_buffering = "no".to_string(); + let mut x_accel_buffering = None; if stream { let mut response_streams = FuturesOrdered::new(); @@ -705,29 +705,28 @@ async fn completions( } }); - (index, header_rx, sse_rx) + (header_rx, sse_rx) }; response_streams.push_back(generate_future); } let mut all_rxs = vec![]; - while let Some((index, header_rx, sse_rx)) = response_streams.next().await { + while let Some((header_rx, sse_rx)) = response_streams.next().await { all_rxs.push(sse_rx); // get the headers from the first response of each stream let headers = header_rx.await.expect("Failed to get headers"); - if index == 0 { + if x_compute_type.is_none() { x_compute_type = headers .get("x-compute-type") .and_then(|v| v.to_str().ok()) - .unwrap_or("unknown") - .to_string(); + .map(|v| v.to_string()); + x_accel_buffering = headers .get("x-accel-buffering") .and_then(|v| v.to_str().ok()) - .unwrap_or("no") - .to_string(); + .map(|v| v.to_string()); } x_compute_characters += headers .get("x-compute-characters") @@ -737,9 +736,13 @@ async fn completions( } let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + if let Some(x_compute_type) = x_compute_type { + headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + } headers.insert("x-compute-characters", x_compute_characters.into()); - headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); + if let Some(x_accel_buffering) = x_accel_buffering { + headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); + } // now sink the sse streams into a single stream and remove the ones that are done let stream: AsyncStream, _> = async_stream::stream! { @@ -806,13 +809,20 @@ async fn completions( let choices = generate_responses .into_iter() .map(|(index, headers, Json(generation))| { - let details = generation.details.unwrap_or_default(); - if index == 0 { + let details = generation.details.ok_or(( + // this should never happen but handle if details are missing unexpectedly + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "No details in generation".to_string(), + error_type: "no details".to_string(), + }), + ))?; + + if x_compute_type.is_none() { x_compute_type = headers .get("x-compute-type") .and_then(|v| v.to_str().ok()) - .unwrap_or("unknown") - .to_string(); + .map(|v| v.to_string()); } // accumulate headers and usage from each response @@ -857,14 +867,15 @@ async fn completions( completion_tokens += details.generated_tokens; total_tokens += details.prefill.len() as u32 + details.generated_tokens; - CompletionComplete { + Ok(CompletionComplete { finish_reason: details.finish_reason.to_string(), index: index as u32, logprobs: None, text: generation.generated_text, - } + }) }) - .collect::>(); + .collect::, _>>() + .map_err(|(status, Json(err))| (status, Json(err)))?; let response = Completion { id: "".to_string(), @@ -886,7 +897,9 @@ async fn completions( // headers similar to `generate` but aggregated let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + if let Some(x_compute_type) = x_compute_type { + headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + } headers.insert("x-compute-characters", x_compute_characters.into()); headers.insert("x-total-time", x_total_time.into()); headers.insert("x-validation-time", x_validation_time.into()); @@ -895,8 +908,9 @@ async fn completions( headers.insert("x-time-per-token", x_time_per_token.into()); headers.insert("x-prompt-tokens", x_prompt_tokens.into()); headers.insert("x-generated-tokens", x_generated_tokens.into()); - headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); - + if let Some(x_accel_buffering) = x_accel_buffering { + headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); + } Ok((headers, Json(response)).into_response()) } }