Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00.
fix: improve headers and add streaming test
This commit is contained in:
parent 57606c447b
commit 942e002674
Integration test changes (Python):

```diff
@@ -1,11 +1,13 @@
 import pytest
 import requests
+import json
+from aiohttp import ClientSession


 @pytest.fixture(scope="module")
 def flash_llama_completion_handle(launcher):
     with launcher(
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     ) as handle:
         yield handle
@@ -20,7 +22,9 @@ async def flash_llama_completion(flash_llama_completion_handle):
     # method for it. Instead, we use the `requests` library to make the HTTP request directly.


-def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snapshot):
+def test_flash_llama_completion_single_prompt(
+    flash_llama_completion, response_snapshot
+):
     response = requests.post(
         f"{flash_llama_completion.base_url}/v1/completions",
         json={
@@ -36,7 +40,7 @@ def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snap
     assert len(response["choices"]) == 1


-def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snapshot):
+def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
     response = requests.post(
         f"{flash_llama_completion.base_url}/v1/completions",
         json={
@@ -54,3 +58,33 @@ def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snaps
     all_indexes = [choice["index"] for choice in response["choices"]]
     all_indexes.sort()
     assert all_indexes == [0, 1, 2, 3, 4]
+
+
+async def test_flash_llama_completion_many_prompts_stream(
+    flash_llama_completion, response_snapshot
+):
+    request = {
+        "model": "tgi",
+        "prompt": ["Say", "this", "is", "a", "test"],
+        "max_tokens": 5,
+        "seed": 0,
+        "stream": True,
+    }
+
+    headers = {
+        "Content-Type": "application/json",
+    }
+
+    url = f"{flash_llama_completion.base_url}/v1/completions"
+
+    async with ClientSession(headers=headers) as session:
+        async with session.post(url, json=request) as resp:
+            # iterate over the stream
+            async for chunk in resp.content.iter_any():
+                # strip data: prefix and convert to json
+                data = json.loads(chunk.decode("utf-8")[5:])
+                assert "choices" in data
+                assert len(data["choices"]) == 1
+                assert data["choices"][0]["index"] in [*range(len(request["prompt"]))]
+
+            assert resp.status == 200
```
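The streaming test parses each chunk by dropping the five-character `data:` prefix that the SSE wire format puts in front of every event. Note that `iter_any()` yields whatever bytes are currently buffered, so the test implicitly assumes each read contains exactly one complete event, which holds for small responses like this one. A minimal sketch of that parsing step, using a made-up chunk rather than a captured TGI response:

```python
import json

# Hypothetical SSE chunk; the payload shape mirrors what the test above
# asserts on, but the values here are invented.
chunk = b'data: {"choices": [{"index": 2, "text": " test"}]}'

# Drop the 5-byte "data:" prefix, then parse the JSON payload;
# json.loads tolerates the leading space left over from "data: ".
data = json.loads(chunk.decode("utf-8")[5:])

assert "choices" in data
assert data["choices"][0]["index"] == 2
```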
Router changes (Rust), in the `completions` handler — first, the streaming branch:

```diff
@@ -639,6 +639,10 @@ async fn completions(
         generate_requests.push(generate_request);
     }

+    let mut x_compute_type = "unknown".to_string();
+    let mut x_compute_characters = 0u32;
+    let mut x_accel_buffering = "no".to_string();
+
     if stream {
         let response_streams = FuturesUnordered::new();
         for (index, generate_request) in generate_requests.into_iter().enumerate() {
@@ -677,7 +681,6 @@ async fn completions(
                     |data| data,
                 )
             };
-
             let (headers, response_stream) = generate_stream_internal(
                 infer.clone(),
                 compute_type.clone(),
@@ -685,7 +688,20 @@ async fn completions(
                 on_message_callback,
             )
             .await;
+
+            if index == 0 {
+                x_compute_type = headers
+                    .get("x-compute-type")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or("unknown".to_string());
+                x_accel_buffering = headers
+                    .get("x-accel-buffering")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or("no".to_string());
+            }
+            x_compute_characters += headers
+                .get("x-compute-characters")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
             response_streams.push((headers, Box::pin(response_stream)));
         }
@@ -698,8 +714,13 @@ async fn completions(
             }
         };

+        let mut headers = HeaderMap::new();
+        headers.insert("x-compute-type", x_compute_type.parse().unwrap());
+        headers.insert("x-compute-characters", x_compute_characters.into());
+        headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
+
         let sse = Sse::new(stream).keep_alive(KeepAlive::default());
-        return Ok((HeaderMap::new(), sse).into_response());
+        return Ok((headers, sse).into_response());
     }

     let current_time = std::time::SystemTime::now()
```
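HTTP response headers go out once, before any body bytes, so the streaming branch has to assemble `x-compute-type`, `x-compute-characters`, and `x-accel-buffering` from the per-prompt streams before handing the merged stream to `Sse`. That is why the first prompt's metadata is sampled at `index == 0` and the character counts are summed as each stream is created, instead of returning an empty `HeaderMap` as before. From the client side, the aggregated values are visible on the streaming response itself; a rough sketch (the localhost URL is an assumption about where TGI is listening):

```python
import asyncio
import json

from aiohttp import ClientSession


async def main():
    # Assumed local TGI endpoint; adjust host/port for your deployment.
    url = "http://localhost:3000/v1/completions"
    payload = {
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 5,
        "stream": True,
    }
    async with ClientSession() as session:
        async with session.post(url, json=payload) as resp:
            # Sent once, before the first chunk, aggregated by the handler above.
            print(resp.headers.get("x-compute-type"))
            print(resp.headers.get("x-accel-buffering"))  # "no" disables proxy buffering
            # Same "data:" prefix-stripping as the test above.
            async for chunk in resp.content.iter_any():
                data = json.loads(chunk.decode("utf-8")[5:])
                print(data["choices"][0]["index"])


asyncio.run(main())
```

The non-streaming branch, below, performs the same aggregation over fully generated responses.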
```diff
@@ -722,11 +743,7 @@ async fn completions(
     let mut completion_tokens = 0u32;
     let mut total_tokens = 0u32;

-    let mut headers = HeaderMap::new();
-
-    let mut x_compute_type: Option<String> = None;
     let mut x_compute_time = 0u32;
-    let mut x_compute_characters = 0u32;
     let mut x_total_time = 0u32;
     let mut x_validation_time = 0u32;
     let mut x_queue_time = 0u32;
@@ -735,44 +752,57 @@ async fn completions(
     let mut x_prompt_tokens = 0u32;
     let mut x_generated_tokens = 0u32;

-    // helper closure to extract a header value or default to 0
-    let extract_or_zero = |headers: &HeaderMap, key: &str| {
-        headers
-            .get(key)
-            .and_then(|v| v.to_str().ok())
-            .unwrap_or("0")
-            .parse::<u32>()
-            .unwrap_or(0)
-    };
-
     let choices = generate_responses
         .into_iter()
         .enumerate()
         .map(|(index, (headers, Json(generation)))| {
             let details = generation.details.unwrap_or_default();
-            if x_compute_type.is_none() {
-                x_compute_type = Some(
-                    headers
-                        .get("x-compute-type")
-                        .unwrap()
-                        .to_str()
-                        .unwrap()
-                        .to_string(),
-                );
+            if index == 0 {
+                x_compute_type = headers
+                    .get("x-compute-type")
+                    .and_then(|v| v.to_str().ok())
+                    .unwrap_or("unknown")
+                    .to_string();
             }

-            // update headers
-            x_compute_time += extract_or_zero(&headers, "x-compute-time");
-            x_compute_characters += extract_or_zero(&headers, "x-compute-characters");
-            x_total_time += extract_or_zero(&headers, "x-total-time");
-            x_validation_time += extract_or_zero(&headers, "x-validation-time");
-            x_queue_time += extract_or_zero(&headers, "x-queue-time");
-            x_inference_time += extract_or_zero(&headers, "x-inference-time");
-            x_time_per_token += extract_or_zero(&headers, "x-time-per-token");
-            x_prompt_tokens += extract_or_zero(&headers, "x-prompt-tokens");
-            x_generated_tokens += extract_or_zero(&headers, "x-generated-tokens");
+            // accumulate headers and usage from each response
+            x_compute_time += headers
+                .get("x-compute-time")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_compute_characters += headers
+                .get("x-compute-characters")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_total_time += headers
+                .get("x-total-time")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_validation_time += headers
+                .get("x-validation-time")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_queue_time += headers
+                .get("x-queue-time")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_inference_time += headers
+                .get("x-inference-time")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_time_per_token += headers
+                .get("x-time-per-token")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_prompt_tokens += headers
+                .get("x-prompt-tokens")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);
+            x_generated_tokens += headers
+                .get("x-generated-tokens")
+                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .unwrap_or(0);

             // update usage
             prompt_tokens += details.prefill.len() as u32;
             completion_tokens += details.generated_tokens;
             total_tokens += details.prefill.len() as u32 + details.generated_tokens;
@@ -786,45 +816,6 @@ async fn completions(
         })
         .collect::<Vec<_>>();

-    // Headers similar to `generate` but aggregated
-    headers.insert(
-        "x-compute-type",
-        x_compute_type
-            .unwrap_or("unknown".to_string())
-            .parse()
-            .unwrap(),
-    );
-    headers.insert(
-        "x-compute-time",
-        x_compute_time.to_string().parse().unwrap(),
-    );
-    headers.insert(
-        "x-compute-characters",
-        x_compute_characters.to_string().parse().unwrap(),
-    );
-    headers.insert("x-total-time", x_total_time.to_string().parse().unwrap());
-    headers.insert(
-        "x-validation-time",
-        x_validation_time.to_string().parse().unwrap(),
-    );
-    headers.insert("x-queue-time", x_queue_time.to_string().parse().unwrap());
-    headers.insert(
-        "x-inference-time",
-        x_inference_time.to_string().parse().unwrap(),
-    );
-    headers.insert(
-        "x-time-per-token",
-        x_time_per_token.to_string().parse().unwrap(),
-    );
-    headers.insert(
-        "x-prompt-tokens",
-        x_prompt_tokens.to_string().parse().unwrap(),
-    );
-    headers.insert(
-        "x-generated-tokens",
-        x_generated_tokens.to_string().parse().unwrap(),
-    );
-
     let response = Completion {
         id: "".to_string(),
         object: "text_completion".to_string(),
@@ -839,6 +830,19 @@ async fn completions(
         },
     };

+    // headers similar to `generate` but aggregated
+    let mut headers = HeaderMap::new();
+    headers.insert("x-compute-type", x_compute_type.parse().unwrap());
+    headers.insert("x-compute-characters", x_compute_characters.into());
+    headers.insert("x-total-time", x_total_time.into());
+    headers.insert("x-validation-time", x_validation_time.into());
+    headers.insert("x-queue-time", x_queue_time.into());
+    headers.insert("x-inference-time", x_inference_time.into());
+    headers.insert("x-time-per-token", x_time_per_token.into());
+    headers.insert("x-prompt-tokens", x_prompt_tokens.into());
+    headers.insert("x-generated-tokens", x_generated_tokens.into());
+    headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
+
     Ok((headers, Json(response)).into_response())
 }
```
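Both branches now share one extraction idiom: `headers.get(name).and_then(|v| v.to_str().ok()?.parse().ok()).unwrap_or(default)`, where the `?` short-circuits the closure to `None` if the header is missing or unparseable. This replaces both the panicking `unwrap()` chain on `x-compute-type` and the `extract_or_zero` helper. The summed metrics land on the final response, so a client can read per-batch totals directly; a rough sketch (localhost URL assumed, header names taken from the insertions above):

```python
import requests

# Assumed local TGI endpoint; adjust host/port for your deployment.
response = requests.post(
    "http://localhost:3000/v1/completions",
    json={
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 5,
    },
)

# Aggregated across all five prompts by the handler above:
# x-compute-type comes from the first response, the counters are sums.
for name in (
    "x-compute-type",
    "x-total-time",
    "x-prompt-tokens",
    "x-generated-tokens",
):
    print(name, response.headers.get(name))
```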