diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index d14c98f0..d410ad2f 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -1,11 +1,13 @@ import pytest import requests +import json +from aiohttp import ClientSession @pytest.fixture(scope="module") def flash_llama_completion_handle(launcher): with launcher( - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ) as handle: yield handle @@ -20,7 +22,9 @@ async def flash_llama_completion(flash_llama_completion_handle): # method for it. Instead, we use the `requests` library to make the HTTP request directly. -def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snapshot): +def test_flash_llama_completion_single_prompt( + flash_llama_completion, response_snapshot +): response = requests.post( f"{flash_llama_completion.base_url}/v1/completions", json={ @@ -36,7 +40,7 @@ def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snap assert len(response["choices"]) == 1 -def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snapshot): +def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): response = requests.post( f"{flash_llama_completion.base_url}/v1/completions", json={ @@ -54,3 +58,33 @@ def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snaps all_indexes = [choice["index"] for choice in response["choices"]] all_indexes.sort() assert all_indexes == [0, 1, 2, 3, 4] + + +async def test_flash_llama_completion_many_prompts_stream( + flash_llama_completion, response_snapshot +): + request = { + "model": "tgi", + "prompt": ["Say", "this", "is", "a", "test"], + "max_tokens": 5, + "seed": 0, + "stream": True, + } + + headers = { + "Content-Type": "application/json", + } + + url = 
f"{flash_llama_completion.base_url}/v1/completions" + + async with ClientSession(headers=headers) as session: + async with session.post(url, json=request) as resp: + # iterate over the stream + async for chunk in resp.content.iter_any(): + # strip data: prefix and convert to json + data = json.loads(chunk.decode("utf-8")[5:]) + assert "choices" in data + assert len(data["choices"]) == 1 + assert data["choices"][0]["index"] in [*range(len(request["prompt"]))] + + assert resp.status == 200 diff --git a/router/src/server.rs b/router/src/server.rs index 51851dfd..3ea6813e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -639,6 +639,10 @@ async fn completions( generate_requests.push(generate_request); } + let mut x_compute_type = "unknown".to_string(); + let mut x_compute_characters = 0u32; + let mut x_accel_buffering = "no".to_string(); + if stream { let response_streams = FuturesUnordered::new(); for (index, generate_request) in generate_requests.into_iter().enumerate() { @@ -677,7 +681,6 @@ async fn completions( |data| data, ) }; - let (headers, response_stream) = generate_stream_internal( infer.clone(), compute_type.clone(), @@ -685,7 +688,20 @@ async fn completions( on_message_callback, ) .await; - + if index == 0 { + x_compute_type = headers + .get("x-compute-type") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or("unknown".to_string()); + x_accel_buffering = headers + .get("x-accel-buffering") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or("no".to_string()); + } + x_compute_characters += headers + .get("x-compute-characters") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); response_streams.push((headers, Box::pin(response_stream))); } @@ -698,8 +714,13 @@ async fn completions( } }; + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + headers.insert("x-compute-characters", x_compute_characters.into()); + headers.insert("x-accel-buffering", 
x_accel_buffering.parse().unwrap()); + + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); - return Ok((HeaderMap::new(), sse).into_response()); + return Ok((headers, sse).into_response()); } let current_time = std::time::SystemTime::now() @@ -722,11 +743,7 @@ async fn completions( let mut completion_tokens = 0u32; let mut total_tokens = 0u32; - let mut headers = HeaderMap::new(); - - let mut x_compute_type: Option<String> = None; let mut x_compute_time = 0u32; - let mut x_compute_characters = 0u32; let mut x_total_time = 0u32; let mut x_validation_time = 0u32; let mut x_queue_time = 0u32; @@ -735,44 +752,57 @@ async fn completions( let mut x_prompt_tokens = 0u32; let mut x_generated_tokens = 0u32; - // helper closure to extract a header value or default to 0 - let extract_or_zero = |headers: &HeaderMap, key: &str| { headers .get(key) .and_then(|v| v.to_str().ok()) .unwrap_or("0") .parse::<u32>() .unwrap_or(0) }; - let choices = generate_responses .into_iter() .enumerate() .map(|(index, (headers, Json(generation)))| { let details = generation.details.unwrap_or_default(); - if x_compute_type.is_none() { - x_compute_type = Some( - headers - .get("x-compute-type") - .unwrap() - .to_str() - .unwrap() - .to_string(), - ); + if index == 0 { + x_compute_type = headers + .get("x-compute-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("unknown") + .to_string(); } - // update headers - x_compute_time += extract_or_zero(&headers, "x-compute-time"); - x_compute_characters += extract_or_zero(&headers, "x-compute-characters"); - x_total_time += extract_or_zero(&headers, "x-total-time"); - x_validation_time += extract_or_zero(&headers, "x-validation-time"); - x_queue_time += extract_or_zero(&headers, "x-queue-time"); - x_inference_time += extract_or_zero(&headers, "x-inference-time"); - x_time_per_token += extract_or_zero(&headers, "x-time-per-token"); - x_prompt_tokens += extract_or_zero(&headers, "x-prompt-tokens"); - x_generated_tokens += 
extract_or_zero(&headers, "x-generated-tokens"); + // accumulate headers and usage from each response + x_compute_time += headers + .get("x-compute-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_compute_characters += headers + .get("x-compute-characters") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_total_time += headers + .get("x-total-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_validation_time += headers + .get("x-validation-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_queue_time += headers + .get("x-queue-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_inference_time += headers + .get("x-inference-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_time_per_token += headers + .get("x-time-per-token") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_prompt_tokens += headers + .get("x-prompt-tokens") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_generated_tokens += headers + .get("x-generated-tokens") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); - // update usage prompt_tokens += details.prefill.len() as u32; completion_tokens += details.generated_tokens; total_tokens += details.prefill.len() as u32 + details.generated_tokens; @@ -786,45 +816,6 @@ async fn completions( }) .collect::<Vec<_>>(); - // Headers similar to `generate` but aggregated - headers.insert( - "x-compute-type", - x_compute_type - .unwrap_or("unknown".to_string()) - .parse() - .unwrap(), - ); - headers.insert( - "x-compute-time", - x_compute_time.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - x_compute_characters.to_string().parse().unwrap(), - ); - headers.insert("x-total-time", x_total_time.to_string().parse().unwrap()); - headers.insert( - "x-validation-time", - x_validation_time.to_string().parse().unwrap(), - ); - headers.insert("x-queue-time", 
x_queue_time.to_string().parse().unwrap()); - headers.insert( - "x-inference-time", - x_inference_time.to_string().parse().unwrap(), - ); - headers.insert( - "x-time-per-token", - x_time_per_token.to_string().parse().unwrap(), - ); - headers.insert( - "x-prompt-tokens", - x_prompt_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-generated-tokens", - x_generated_tokens.to_string().parse().unwrap(), - ); - let response = Completion { id: "".to_string(), object: "text_completion".to_string(), @@ -839,6 +830,19 @@ async fn completions( }, }; + // headers similar to `generate` but aggregated + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + headers.insert("x-compute-characters", x_compute_characters.into()); + headers.insert("x-total-time", x_total_time.into()); + headers.insert("x-validation-time", x_validation_time.into()); + headers.insert("x-queue-time", x_queue_time.into()); + headers.insert("x-inference-time", x_inference_time.into()); + headers.insert("x-time-per-token", x_time_per_token.into()); + headers.insert("x-prompt-tokens", x_prompt_tokens.into()); + headers.insert("x-generated-tokens", x_generated_tokens.into()); + headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); + Ok((headers, Json(response)).into_response()) }