fix: improve headers and add streaming test

This commit is contained in:
drbh 2024-04-10 15:22:23 +00:00
parent 57606c447b
commit 942e002674
2 changed files with 117 additions and 79 deletions

View File

@ -1,11 +1,13 @@
import pytest import pytest
import requests import requests
import json
from aiohttp import ClientSession
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher): def flash_llama_completion_handle(launcher):
with launcher( with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
) as handle: ) as handle:
yield handle yield handle
@ -20,7 +22,9 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly. # method for it. Instead, we use the `requests` library to make the HTTP request directly.
def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snapshot): def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot
):
response = requests.post( response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions", f"{flash_llama_completion.base_url}/v1/completions",
json={ json={
@ -36,7 +40,7 @@ def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snap
assert len(response["choices"]) == 1 assert len(response["choices"]) == 1
def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snapshot): def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post( response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions", f"{flash_llama_completion.base_url}/v1/completions",
json={ json={
@ -54,3 +58,33 @@ def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snaps
all_indexes = [choice["index"] for choice in response["choices"]] all_indexes = [choice["index"] for choice in response["choices"]]
all_indexes.sort() all_indexes.sort()
assert all_indexes == [0, 1, 2, 3, 4] assert all_indexes == [0, 1, 2, 3, 4]
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):
request = {
"model": "tgi",
"prompt": ["Say", "this", "is", "a", "test"],
"max_tokens": 5,
"seed": 0,
"stream": True,
}
headers = {
"Content-Type": "application/json",
}
url = f"{flash_llama_completion.base_url}/v1/completions"
async with ClientSession(headers=headers) as session:
async with session.post(url, json=request) as resp:
# iterate over the stream
async for chunk in resp.content.iter_any():
# strip data: prefix and convert to json
data = json.loads(chunk.decode("utf-8")[5:])
assert "choices" in data
assert len(data["choices"]) == 1
assert data["choices"][0]["index"] in [*range(len(request["prompt"]))]
assert resp.status == 200

View File

@ -639,6 +639,10 @@ async fn completions(
generate_requests.push(generate_request); generate_requests.push(generate_request);
} }
let mut x_compute_type = "unknown".to_string();
let mut x_compute_characters = 0u32;
let mut x_accel_buffering = "no".to_string();
if stream { if stream {
let response_streams = FuturesUnordered::new(); let response_streams = FuturesUnordered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() { for (index, generate_request) in generate_requests.into_iter().enumerate() {
@ -677,7 +681,6 @@ async fn completions(
|data| data, |data| data,
) )
}; };
let (headers, response_stream) = generate_stream_internal( let (headers, response_stream) = generate_stream_internal(
infer.clone(), infer.clone(),
compute_type.clone(), compute_type.clone(),
@ -685,7 +688,20 @@ async fn completions(
on_message_callback, on_message_callback,
) )
.await; .await;
if index == 0 {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or("unknown".to_string());
x_accel_buffering = headers
.get("x-accel-buffering")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or("no".to_string());
}
x_compute_characters += headers
.get("x-compute-characters")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
response_streams.push((headers, Box::pin(response_stream))); response_streams.push((headers, Box::pin(response_stream)));
} }
@ -698,8 +714,13 @@ async fn completions(
} }
}; };
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
let sse = Sse::new(stream).keep_alive(KeepAlive::default()); let sse = Sse::new(stream).keep_alive(KeepAlive::default());
return Ok((HeaderMap::new(), sse).into_response()); return Ok((headers, sse).into_response());
} }
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
@ -722,11 +743,7 @@ async fn completions(
let mut completion_tokens = 0u32; let mut completion_tokens = 0u32;
let mut total_tokens = 0u32; let mut total_tokens = 0u32;
let mut headers = HeaderMap::new();
let mut x_compute_type: Option<String> = None;
let mut x_compute_time = 0u32; let mut x_compute_time = 0u32;
let mut x_compute_characters = 0u32;
let mut x_total_time = 0u32; let mut x_total_time = 0u32;
let mut x_validation_time = 0u32; let mut x_validation_time = 0u32;
let mut x_queue_time = 0u32; let mut x_queue_time = 0u32;
@ -735,44 +752,57 @@ async fn completions(
let mut x_prompt_tokens = 0u32; let mut x_prompt_tokens = 0u32;
let mut x_generated_tokens = 0u32; let mut x_generated_tokens = 0u32;
// helper closure to extract a header value or default to 0
let extract_or_zero = |headers: &HeaderMap, key: &str| {
headers
.get(key)
.and_then(|v| v.to_str().ok())
.unwrap_or("0")
.parse::<u32>()
.unwrap_or(0)
};
let choices = generate_responses let choices = generate_responses
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(index, (headers, Json(generation)))| { .map(|(index, (headers, Json(generation)))| {
let details = generation.details.unwrap_or_default(); let details = generation.details.unwrap_or_default();
if x_compute_type.is_none() { if index == 0 {
x_compute_type = Some( x_compute_type = headers
headers
.get("x-compute-type") .get("x-compute-type")
.unwrap() .and_then(|v| v.to_str().ok())
.to_str() .unwrap_or("unknown")
.unwrap() .to_string();
.to_string(),
);
} }
// update headers // accumulate headers and usage from each response
x_compute_time += extract_or_zero(&headers, "x-compute-time"); x_compute_time += headers
x_compute_characters += extract_or_zero(&headers, "x-compute-characters"); .get("x-compute-time")
x_total_time += extract_or_zero(&headers, "x-total-time"); .and_then(|v| v.to_str().ok()?.parse().ok())
x_validation_time += extract_or_zero(&headers, "x-validation-time"); .unwrap_or(0);
x_queue_time += extract_or_zero(&headers, "x-queue-time"); x_compute_characters += headers
x_inference_time += extract_or_zero(&headers, "x-inference-time"); .get("x-compute-characters")
x_time_per_token += extract_or_zero(&headers, "x-time-per-token"); .and_then(|v| v.to_str().ok()?.parse().ok())
x_prompt_tokens += extract_or_zero(&headers, "x-prompt-tokens"); .unwrap_or(0);
x_generated_tokens += extract_or_zero(&headers, "x-generated-tokens"); x_total_time += headers
.get("x-total-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_validation_time += headers
.get("x-validation-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_queue_time += headers
.get("x-queue-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_inference_time += headers
.get("x-inference-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_time_per_token += headers
.get("x-time-per-token")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_prompt_tokens += headers
.get("x-prompt-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_generated_tokens += headers
.get("x-generated-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
// update usage
prompt_tokens += details.prefill.len() as u32; prompt_tokens += details.prefill.len() as u32;
completion_tokens += details.generated_tokens; completion_tokens += details.generated_tokens;
total_tokens += details.prefill.len() as u32 + details.generated_tokens; total_tokens += details.prefill.len() as u32 + details.generated_tokens;
@ -786,45 +816,6 @@ async fn completions(
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// Headers similar to `generate` but aggregated
headers.insert(
"x-compute-type",
x_compute_type
.unwrap_or("unknown".to_string())
.parse()
.unwrap(),
);
headers.insert(
"x-compute-time",
x_compute_time.to_string().parse().unwrap(),
);
headers.insert(
"x-compute-characters",
x_compute_characters.to_string().parse().unwrap(),
);
headers.insert("x-total-time", x_total_time.to_string().parse().unwrap());
headers.insert(
"x-validation-time",
x_validation_time.to_string().parse().unwrap(),
);
headers.insert("x-queue-time", x_queue_time.to_string().parse().unwrap());
headers.insert(
"x-inference-time",
x_inference_time.to_string().parse().unwrap(),
);
headers.insert(
"x-time-per-token",
x_time_per_token.to_string().parse().unwrap(),
);
headers.insert(
"x-prompt-tokens",
x_prompt_tokens.to_string().parse().unwrap(),
);
headers.insert(
"x-generated-tokens",
x_generated_tokens.to_string().parse().unwrap(),
);
let response = Completion { let response = Completion {
id: "".to_string(), id: "".to_string(),
object: "text_completion".to_string(), object: "text_completion".to_string(),
@ -839,6 +830,19 @@ async fn completions(
}, },
}; };
// headers similar to `generate` but aggregated
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-total-time", x_total_time.into());
headers.insert("x-validation-time", x_validation_time.into());
headers.insert("x-queue-time", x_queue_time.into());
headers.insert("x-inference-time", x_inference_time.into());
headers.insert("x-time-per-token", x_time_per_token.into());
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
headers.insert("x-generated-tokens", x_generated_tokens.into());
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
Ok((headers, Json(response)).into_response()) Ok((headers, Json(response)).into_response())
} }