mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: improve headers and add streaming test
This commit is contained in:
parent
57606c447b
commit
942e002674
@ -1,11 +1,13 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
import json
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_llama_completion_handle(launcher):
|
def flash_llama_completion_handle(launcher):
|
||||||
with launcher(
|
with launcher(
|
||||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
) as handle:
|
) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
@ -20,7 +22,9 @@ async def flash_llama_completion(flash_llama_completion_handle):
|
|||||||
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
|
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
|
||||||
|
|
||||||
|
|
||||||
def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snapshot):
|
def test_flash_llama_completion_single_prompt(
|
||||||
|
flash_llama_completion, response_snapshot
|
||||||
|
):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{flash_llama_completion.base_url}/v1/completions",
|
f"{flash_llama_completion.base_url}/v1/completions",
|
||||||
json={
|
json={
|
||||||
@ -36,7 +40,7 @@ def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snap
|
|||||||
assert len(response["choices"]) == 1
|
assert len(response["choices"]) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snapshot):
|
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{flash_llama_completion.base_url}/v1/completions",
|
f"{flash_llama_completion.base_url}/v1/completions",
|
||||||
json={
|
json={
|
||||||
@ -54,3 +58,33 @@ def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snaps
|
|||||||
all_indexes = [choice["index"] for choice in response["choices"]]
|
all_indexes = [choice["index"] for choice in response["choices"]]
|
||||||
all_indexes.sort()
|
all_indexes.sort()
|
||||||
assert all_indexes == [0, 1, 2, 3, 4]
|
assert all_indexes == [0, 1, 2, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    """Stream completions for several prompts and validate every SSE event.

    Posts a streaming request with five prompts to ``/v1/completions`` and
    checks that each streamed event carries exactly one choice whose
    ``index`` refers to one of the submitted prompts.

    ``response_snapshot`` is accepted but not used by this test.
    """
    request = {
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 5,
        "seed": 0,
        "stream": True,
    }

    headers = {
        "Content-Type": "application/json",
    }

    url = f"{flash_llama_completion.base_url}/v1/completions"

    valid_indexes = range(len(request["prompt"]))
    events_seen = 0

    async with ClientSession(headers=headers) as session:
        async with session.post(url, json=request) as resp:
            # Check the status first: an error response would otherwise show
            # up as a confusing JSON decoding failure while reading the body.
            assert resp.status == 200

            # `iter_any()` yields whatever bytes are available, so a chunk may
            # hold several events or only part of one. Buffer the bytes and
            # only parse events that have been received completely.
            pending = b""
            async for chunk in resp.content.iter_any():
                pending += chunk
                events, pending = _parse_sse_events(pending)
                for data in events:
                    assert "choices" in data
                    assert len(data["choices"]) == 1
                    assert data["choices"][0]["index"] in valid_indexes
                    events_seen += 1

    # Guard against the loop passing vacuously on an empty stream.
    assert events_seen > 0


def _parse_sse_events(buffer):
    """Decode the complete SSE events contained in ``buffer``.

    Events are assumed to be separated by a blank line (``\\n\\n``), as
    described by the Server-Sent Events format.

    Returns a ``(payloads, remainder)`` tuple: ``payloads`` is the list of
    JSON-decoded ``data:`` payloads, ``remainder`` is the trailing bytes of
    an event that has not been fully received yet.
    """
    *complete_events, remainder = buffer.split(b"\n\n")

    payloads = []
    for raw_event in complete_events:
        data_lines = [
            line[len("data:"):].strip()
            for line in raw_event.decode("utf-8").splitlines()
            if line.startswith("data:")
        ]
        # Frames without a data field (e.g. keep-alive comments) carry no payload.
        if not data_lines:
            continue
        payload = "\n".join(data_lines)
        # Tolerate an OpenAI-style end-of-stream sentinel if the server sends one.
        if payload == "[DONE]":
            continue
        payloads.append(json.loads(payload))

    return payloads, remainder
|
||||||
|
@ -639,6 +639,10 @@ async fn completions(
|
|||||||
generate_requests.push(generate_request);
|
generate_requests.push(generate_request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut x_compute_type = "unknown".to_string();
|
||||||
|
let mut x_compute_characters = 0u32;
|
||||||
|
let mut x_accel_buffering = "no".to_string();
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
let response_streams = FuturesUnordered::new();
|
let response_streams = FuturesUnordered::new();
|
||||||
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
||||||
@ -677,7 +681,6 @@ async fn completions(
|
|||||||
|data| data,
|
|data| data,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let (headers, response_stream) = generate_stream_internal(
|
let (headers, response_stream) = generate_stream_internal(
|
||||||
infer.clone(),
|
infer.clone(),
|
||||||
compute_type.clone(),
|
compute_type.clone(),
|
||||||
@ -685,7 +688,20 @@ async fn completions(
|
|||||||
on_message_callback,
|
on_message_callback,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
if index == 0 {
|
||||||
|
x_compute_type = headers
|
||||||
|
.get("x-compute-type")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or("unknown".to_string());
|
||||||
|
x_accel_buffering = headers
|
||||||
|
.get("x-accel-buffering")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or("no".to_string());
|
||||||
|
}
|
||||||
|
x_compute_characters += headers
|
||||||
|
.get("x-compute-characters")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
response_streams.push((headers, Box::pin(response_stream)));
|
response_streams.push((headers, Box::pin(response_stream)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -698,8 +714,13 @@ async fn completions(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||||
|
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||||
|
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||||
|
|
||||||
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
|
||||||
return Ok((HeaderMap::new(), sse).into_response());
|
return Ok((headers, sse).into_response());
|
||||||
}
|
}
|
||||||
|
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
@ -722,11 +743,7 @@ async fn completions(
|
|||||||
let mut completion_tokens = 0u32;
|
let mut completion_tokens = 0u32;
|
||||||
let mut total_tokens = 0u32;
|
let mut total_tokens = 0u32;
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
|
|
||||||
let mut x_compute_type: Option<String> = None;
|
|
||||||
let mut x_compute_time = 0u32;
|
let mut x_compute_time = 0u32;
|
||||||
let mut x_compute_characters = 0u32;
|
|
||||||
let mut x_total_time = 0u32;
|
let mut x_total_time = 0u32;
|
||||||
let mut x_validation_time = 0u32;
|
let mut x_validation_time = 0u32;
|
||||||
let mut x_queue_time = 0u32;
|
let mut x_queue_time = 0u32;
|
||||||
@ -735,44 +752,57 @@ async fn completions(
|
|||||||
let mut x_prompt_tokens = 0u32;
|
let mut x_prompt_tokens = 0u32;
|
||||||
let mut x_generated_tokens = 0u32;
|
let mut x_generated_tokens = 0u32;
|
||||||
|
|
||||||
// helper closure to extract a header value or default to 0
|
|
||||||
let extract_or_zero = |headers: &HeaderMap, key: &str| {
|
|
||||||
headers
|
|
||||||
.get(key)
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.unwrap_or("0")
|
|
||||||
.parse::<u32>()
|
|
||||||
.unwrap_or(0)
|
|
||||||
};
|
|
||||||
|
|
||||||
let choices = generate_responses
|
let choices = generate_responses
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(index, (headers, Json(generation)))| {
|
.map(|(index, (headers, Json(generation)))| {
|
||||||
let details = generation.details.unwrap_or_default();
|
let details = generation.details.unwrap_or_default();
|
||||||
if x_compute_type.is_none() {
|
if index == 0 {
|
||||||
x_compute_type = Some(
|
x_compute_type = headers
|
||||||
headers
|
.get("x-compute-type")
|
||||||
.get("x-compute-type")
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap()
|
.unwrap_or("unknown")
|
||||||
.to_str()
|
.to_string();
|
||||||
.unwrap()
|
|
||||||
.to_string(),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update headers
|
// accumulate headers and usage from each response
|
||||||
x_compute_time += extract_or_zero(&headers, "x-compute-time");
|
x_compute_time += headers
|
||||||
x_compute_characters += extract_or_zero(&headers, "x-compute-characters");
|
.get("x-compute-time")
|
||||||
x_total_time += extract_or_zero(&headers, "x-total-time");
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
x_validation_time += extract_or_zero(&headers, "x-validation-time");
|
.unwrap_or(0);
|
||||||
x_queue_time += extract_or_zero(&headers, "x-queue-time");
|
x_compute_characters += headers
|
||||||
x_inference_time += extract_or_zero(&headers, "x-inference-time");
|
.get("x-compute-characters")
|
||||||
x_time_per_token += extract_or_zero(&headers, "x-time-per-token");
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
x_prompt_tokens += extract_or_zero(&headers, "x-prompt-tokens");
|
.unwrap_or(0);
|
||||||
x_generated_tokens += extract_or_zero(&headers, "x-generated-tokens");
|
x_total_time += headers
|
||||||
|
.get("x-total-time")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
x_validation_time += headers
|
||||||
|
.get("x-validation-time")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
x_queue_time += headers
|
||||||
|
.get("x-queue-time")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
x_inference_time += headers
|
||||||
|
.get("x-inference-time")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
x_time_per_token += headers
|
||||||
|
.get("x-time-per-token")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
x_prompt_tokens += headers
|
||||||
|
.get("x-prompt-tokens")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
x_generated_tokens += headers
|
||||||
|
.get("x-generated-tokens")
|
||||||
|
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
// update usage
|
|
||||||
prompt_tokens += details.prefill.len() as u32;
|
prompt_tokens += details.prefill.len() as u32;
|
||||||
completion_tokens += details.generated_tokens;
|
completion_tokens += details.generated_tokens;
|
||||||
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
||||||
@ -786,45 +816,6 @@ async fn completions(
|
|||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
// Headers similar to `generate` but aggregated
|
|
||||||
headers.insert(
|
|
||||||
"x-compute-type",
|
|
||||||
x_compute_type
|
|
||||||
.unwrap_or("unknown".to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
"x-compute-time",
|
|
||||||
x_compute_time.to_string().parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
"x-compute-characters",
|
|
||||||
x_compute_characters.to_string().parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert("x-total-time", x_total_time.to_string().parse().unwrap());
|
|
||||||
headers.insert(
|
|
||||||
"x-validation-time",
|
|
||||||
x_validation_time.to_string().parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert("x-queue-time", x_queue_time.to_string().parse().unwrap());
|
|
||||||
headers.insert(
|
|
||||||
"x-inference-time",
|
|
||||||
x_inference_time.to_string().parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
"x-time-per-token",
|
|
||||||
x_time_per_token.to_string().parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
"x-prompt-tokens",
|
|
||||||
x_prompt_tokens.to_string().parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
"x-generated-tokens",
|
|
||||||
x_generated_tokens.to_string().parse().unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let response = Completion {
|
let response = Completion {
|
||||||
id: "".to_string(),
|
id: "".to_string(),
|
||||||
object: "text_completion".to_string(),
|
object: "text_completion".to_string(),
|
||||||
@ -839,6 +830,19 @@ async fn completions(
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// headers similar to `generate` but aggregated
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||||
|
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||||
|
headers.insert("x-total-time", x_total_time.into());
|
||||||
|
headers.insert("x-validation-time", x_validation_time.into());
|
||||||
|
headers.insert("x-queue-time", x_queue_time.into());
|
||||||
|
headers.insert("x-inference-time", x_inference_time.into());
|
||||||
|
headers.insert("x-time-per-token", x_time_per_token.into());
|
||||||
|
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
|
||||||
|
headers.insert("x-generated-tokens", x_generated_tokens.into());
|
||||||
|
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||||
|
|
||||||
Ok((headers, Json(response)).into_response())
|
Ok((headers, Json(response)).into_response())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user