Assert completion snapshots in tests; handle channel errors in the router

integration-tests: the three completion tests ended with a bare
`response == response_snapshot` expression, so the comparison result was
computed and thrown away. Returning the comparison does not check it
either, because pytest never uses a test's return value as the verdict
(newer pytest versions warn about, or reject, a non-None return).
Use `assert` so a snapshot mismatch actually fails the test.

router: in `completions`, replace the two `expect(...)` calls with handled
errors. A failed `sse_tx.send` is logged and ends the forwarding loop; a
failed `header_rx.await` is logged and returned as a 500 `ErrorResponse`.

NOTE(review): this patch was re-expanded from a copy whose line breaks had
been collapsed. Hunk line counts were checked against the `@@` headers, but
the indentation of context lines is reconstructed -- verify against the
target tree, or apply with `git apply --ignore-whitespace`. The post-image
hash on the first `index` line predates the `assert` change and is stale.

diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py
index daa1209a..feec79da 100644
--- a/integration-tests/models/test_completion_prompts.py
+++ b/integration-tests/models/test_completion_prompts.py
@@ -39,7 +39,7 @@ def test_flash_llama_completion_single_prompt(
     response = response.json()
     assert len(response["choices"]) == 1
 
-    response == response_snapshot
+    assert response == response_snapshot
 
 
 def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
@@ -61,7 +61,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
     all_indexes.sort()
     assert all_indexes == [0, 1, 2, 3]
 
-    response == response_snapshot
+    assert response == response_snapshot
 
 
 async def test_flash_llama_completion_many_prompts_stream(
@@ -100,4 +100,4 @@ async def test_flash_llama_completion_many_prompts_stream(
                     assert 0 <= c["choices"][0]["index"] <= 4
 
     assert response.status == 200
-    response == response_snapshot
+    assert response == response_snapshot
diff --git a/router/src/server.rs b/router/src/server.rs
index d140509e..07fbacfc 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -701,7 +701,10 @@ async fn completions(
                 // pin an emit messages to the sse_tx
                 let mut sse = Box::pin(sse);
                 while let Some(event) = sse.next().await {
-                    sse_tx.send(event).expect("Failed to send event");
+                    if sse_tx.send(event).is_err() {
+                        tracing::error!("Failed to send event. Receiver dropped.");
+                        break;
+                    }
                 }
             });
 
@@ -716,7 +719,16 @@ async fn completions(
             all_rxs.push(sse_rx);
 
             // get the headers from the first response of each stream
-            let headers = header_rx.await.expect("Failed to get headers");
+            let headers = header_rx.await.map_err(|e| {
+                tracing::error!("Failed to get headers: {:?}", e);
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: "Failed to get headers".to_string(),
+                        error_type: "headers".to_string(),
+                    }),
+                )
+            })?;
             if x_compute_type.is_none() {
                 x_compute_type = headers
                     .get("x-compute-type")