mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: graceful stream close and fix tests
This commit is contained in:
parent
c7b4cd318f
commit
f2080c4114
@ -39,7 +39,7 @@ def test_flash_llama_completion_single_prompt(
|
|||||||
response = response.json()
|
response = response.json()
|
||||||
assert len(response["choices"]) == 1
|
assert len(response["choices"]) == 1
|
||||||
|
|
||||||
response == response_snapshot
|
return response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
||||||
@ -61,7 +61,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
|
|||||||
all_indexes.sort()
|
all_indexes.sort()
|
||||||
assert all_indexes == [0, 1, 2, 3]
|
assert all_indexes == [0, 1, 2, 3]
|
||||||
|
|
||||||
response == response_snapshot
|
return response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_flash_llama_completion_many_prompts_stream(
|
async def test_flash_llama_completion_many_prompts_stream(
|
||||||
@ -100,4 +100,4 @@ async def test_flash_llama_completion_many_prompts_stream(
|
|||||||
assert 0 <= c["choices"][0]["index"] <= 4
|
assert 0 <= c["choices"][0]["index"] <= 4
|
||||||
|
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
response == response_snapshot
|
return response == response_snapshot
|
||||||
|
@ -701,7 +701,10 @@ async fn completions(
|
|||||||
// pin an emit messages to the sse_tx
|
// pin an emit messages to the sse_tx
|
||||||
let mut sse = Box::pin(sse);
|
let mut sse = Box::pin(sse);
|
||||||
while let Some(event) = sse.next().await {
|
while let Some(event) = sse.next().await {
|
||||||
sse_tx.send(event).expect("Failed to send event");
|
if sse_tx.send(event).is_err() {
|
||||||
|
tracing::error!("Failed to send event. Receiver dropped.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -716,7 +719,16 @@ async fn completions(
|
|||||||
all_rxs.push(sse_rx);
|
all_rxs.push(sse_rx);
|
||||||
|
|
||||||
// get the headers from the first response of each stream
|
// get the headers from the first response of each stream
|
||||||
let headers = header_rx.await.expect("Failed to get headers");
|
let headers = header_rx.await.map_err(|e| {
|
||||||
|
tracing::error!("Failed to get headers: {:?}", e);
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Failed to get headers".to_string(),
|
||||||
|
error_type: "headers".to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
if x_compute_type.is_none() {
|
if x_compute_type.is_none() {
|
||||||
x_compute_type = headers
|
x_compute_type = headers
|
||||||
.get("x-compute-type")
|
.get("x-compute-type")
|
||||||
|
Loading…
Reference in New Issue
Block a user