mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: interleave streams and improve tests
This commit is contained in:
parent
942e002674
commit
16be5a14b3
@ -0,0 +1,44 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " PR for more information?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 1,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": "le Business Incubator is providing a workspace"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 2,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": "hd20220811-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 3,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " severely flawed and often has a substandard"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 4,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": "](https://i.imgur.com/as"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1712862968,
|
||||||
|
"id": "",
|
||||||
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "1.4.5-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 46,
|
||||||
|
"prompt_tokens": 10,
|
||||||
|
"total_tokens": 56
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
"<ClientResponse(http://localhost:9483/v1/completions) [200 OK]>\n<CIMultiDictProxy('Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'x-compute-type': '1-nvidia-a10g', 'x-compute-characters': '72', 'x-accel-buffering': 'no', 'Access-Control-Allow-Origin': '*', 'Vary': 'origin', 'Vary': 'access-control-request-method', 'Vary': 'access-control-request-headers', 'Transfer-Encoding': 'chunked', 'Date': 'Thu, 11 Apr 2024 19:19:32 GMT')>\n"
|
@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " PR for flake8"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1712862926,
|
||||||
|
"id": "",
|
||||||
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "1.4.5-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 5,
|
||||||
|
"prompt_tokens": 6,
|
||||||
|
"total_tokens": 11
|
||||||
|
}
|
||||||
|
}
|
@ -39,6 +39,8 @@ def test_flash_llama_completion_single_prompt(
|
|||||||
response = response.json()
|
response = response.json()
|
||||||
assert len(response["choices"]) == 1
|
assert len(response["choices"]) == 1
|
||||||
|
|
||||||
|
response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@ -46,7 +48,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
|
|||||||
json={
|
json={
|
||||||
"model": "tgi",
|
"model": "tgi",
|
||||||
"prompt": ["Say", "this", "is", "a", "test"],
|
"prompt": ["Say", "this", "is", "a", "test"],
|
||||||
"max_tokens": 5,
|
"max_tokens": 10,
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
},
|
},
|
||||||
headers=flash_llama_completion.headers,
|
headers=flash_llama_completion.headers,
|
||||||
@ -59,32 +61,43 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
|
|||||||
all_indexes.sort()
|
all_indexes.sort()
|
||||||
assert all_indexes == [0, 1, 2, 3, 4]
|
assert all_indexes == [0, 1, 2, 3, 4]
|
||||||
|
|
||||||
|
response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_flash_llama_completion_many_prompts_stream(
|
async def test_flash_llama_completion_many_prompts_stream(
|
||||||
flash_llama_completion, response_snapshot
|
flash_llama_completion, response_snapshot
|
||||||
):
|
):
|
||||||
request = {
|
request = {
|
||||||
"model": "tgi",
|
"model": "tgi",
|
||||||
"prompt": ["Say", "this", "is", "a", "test"],
|
"prompt": [
|
||||||
"max_tokens": 5,
|
"What color is the sky?",
|
||||||
|
"Is water wet?",
|
||||||
|
"What is the capital of France?",
|
||||||
|
"def mai",
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
url = f"{flash_llama_completion.base_url}/v1/completions"
|
url = f"{flash_llama_completion.base_url}/v1/completions"
|
||||||
|
|
||||||
async with ClientSession(headers=headers) as session:
|
async with ClientSession(headers=flash_llama_completion.headers) as session:
|
||||||
async with session.post(url, json=request) as resp:
|
async with session.post(url, json=request) as response:
|
||||||
# iterate over the stream
|
# iterate over the stream
|
||||||
async for chunk in resp.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
# strip data: prefix and convert to json
|
# remove "data:"
|
||||||
data = json.loads(chunk.decode("utf-8")[5:])
|
chunk = chunk.decode().split("\n\n")
|
||||||
assert "choices" in data
|
# remove "data:" if present
|
||||||
assert len(data["choices"]) == 1
|
chunk = [c.replace("data:", "") for c in chunk]
|
||||||
assert data["choices"][0]["index"] in [*range(len(request["prompt"]))]
|
# remove empty strings
|
||||||
|
chunk = [c for c in chunk if c]
|
||||||
|
# parse json
|
||||||
|
chunk = [json.loads(c) for c in chunk]
|
||||||
|
|
||||||
assert resp.status == 200
|
for c in chunk:
|
||||||
|
assert "choices" in c
|
||||||
|
assert 0 <= c["choices"][0]["index"] <= 4
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
response == response_snapshot
|
||||||
|
@ -15,7 +15,8 @@ use crate::{
|
|||||||
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
||||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||||
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
@ -23,8 +24,8 @@ use axum::response::{IntoResponse, Response};
|
|||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{http, Json, Router};
|
use axum::{http, Json, Router};
|
||||||
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
||||||
use futures::stream::FuturesUnordered;
|
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
|
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||||
@ -35,7 +36,9 @@ use std::sync::atomic::AtomicBool;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{ShardInfo, ShardedClient};
|
use text_generation_client::{ShardInfo, ShardedClient};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
use tokio::select;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||||
use tracing::{info_span, instrument, Instrument};
|
use tracing::{info_span, instrument, Instrument};
|
||||||
@ -644,11 +647,16 @@ async fn completions(
|
|||||||
let mut x_accel_buffering = "no".to_string();
|
let mut x_accel_buffering = "no".to_string();
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
let response_streams = FuturesUnordered::new();
|
let mut response_streams = FuturesOrdered::new();
|
||||||
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
||||||
let model_id = info.model_id.clone();
|
let model_id = info.model_id.clone();
|
||||||
let system_fingerprint =
|
let system_fingerprint =
|
||||||
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
|
let infer_clone = infer.clone();
|
||||||
|
let compute_type_clone = compute_type.clone();
|
||||||
|
|
||||||
|
// Create a future for each generate_stream_internal call.
|
||||||
|
let generate_future = async move {
|
||||||
let on_message_callback = move |stream_token: StreamResponse| {
|
let on_message_callback = move |stream_token: StreamResponse| {
|
||||||
let event = Event::default();
|
let event = Event::default();
|
||||||
|
|
||||||
@ -673,52 +681,91 @@ async fn completions(
|
|||||||
model: model_id.clone(),
|
model: model_id.clone(),
|
||||||
system_fingerprint: system_fingerprint.clone(),
|
system_fingerprint: system_fingerprint.clone(),
|
||||||
})
|
})
|
||||||
.map_or_else(
|
.map_or_else(|_e| Event::default(), |data| data)
|
||||||
|e| {
|
|
||||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
|
||||||
Event::default()
|
|
||||||
},
|
|
||||||
|data| data,
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
let (headers, response_stream) = generate_stream_internal(
|
|
||||||
infer.clone(),
|
let (header_tx, header_rx) = oneshot::channel();
|
||||||
compute_type.clone(),
|
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let (header_map, sse) = generate_stream_internal(
|
||||||
|
infer_clone.clone(),
|
||||||
|
compute_type_clone.clone(),
|
||||||
Json(generate_request),
|
Json(generate_request),
|
||||||
on_message_callback,
|
on_message_callback,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
// send and dont wait for response
|
||||||
|
let _ = header_tx.send(header_map);
|
||||||
|
|
||||||
|
// pin an emit messages to the sse_tx
|
||||||
|
let mut sse = Box::pin(sse);
|
||||||
|
while let Some(event) = sse.next().await {
|
||||||
|
sse_tx.send(event).expect("Failed to send event");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
(index, header_rx, sse_rx)
|
||||||
|
};
|
||||||
|
response_streams.push_back(generate_future);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut all_rxs = vec![];
|
||||||
|
|
||||||
|
while let Some((index, header_rx, sse_rx)) = response_streams.next().await {
|
||||||
|
all_rxs.push(sse_rx);
|
||||||
|
|
||||||
|
// get the headers from the first response of each stream
|
||||||
|
let headers = header_rx.await.expect("Failed to get headers");
|
||||||
if index == 0 {
|
if index == 0 {
|
||||||
x_compute_type = headers
|
x_compute_type = headers
|
||||||
.get("x-compute-type")
|
.get("x-compute-type")
|
||||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("unknown".to_string());
|
.unwrap_or("unknown")
|
||||||
|
.to_string();
|
||||||
x_accel_buffering = headers
|
x_accel_buffering = headers
|
||||||
.get("x-accel-buffering")
|
.get("x-accel-buffering")
|
||||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("no".to_string());
|
.unwrap_or("no")
|
||||||
|
.to_string();
|
||||||
}
|
}
|
||||||
x_compute_characters += headers
|
x_compute_characters += headers
|
||||||
.get("x-compute-characters")
|
.get("x-compute-characters")
|
||||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.parse().ok())
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
response_streams.push((headers, Box::pin(response_stream)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let stream = async_stream::stream! {
|
|
||||||
for response_stream in response_streams {
|
|
||||||
let (_headers, mut inner_stream) = response_stream;
|
|
||||||
while let Some(event) = inner_stream.next().await {
|
|
||||||
yield event;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||||
|
|
||||||
|
// now sink the sse streams into a single stream and remove the ones that are done
|
||||||
|
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
|
||||||
|
loop {
|
||||||
|
let mut i = 0;
|
||||||
|
while i < all_rxs.len() {
|
||||||
|
let rx = &mut all_rxs[i];
|
||||||
|
select! {
|
||||||
|
Some(event) = rx.recv() => {
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
else => {
|
||||||
|
all_rxs.remove(i);
|
||||||
|
continue; // skip the increment to handle the next element at the same index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += 1; // only increment when no element was removed
|
||||||
|
}
|
||||||
|
|
||||||
|
if all_rxs.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
|
||||||
return Ok((headers, sse).into_response());
|
return Ok((headers, sse).into_response());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user