feat: interleave streams and improve tests

This commit is contained in:
drbh 2024-04-11 19:26:20 +00:00
parent 942e002674
commit 16be5a14b3
5 changed files with 192 additions and 67 deletions

View File

@ -0,0 +1,44 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"text": " PR for more information?"
},
{
"finish_reason": "length",
"index": 1,
"logprobs": null,
"text": "le Business Incubator is providing a workspace"
},
{
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": "hd20220811-"
},
{
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": " severely flawed and often has a substandard"
},
{
"finish_reason": "length",
"index": 4,
"logprobs": null,
"text": "](https://i.imgur.com/as"
}
],
"created": 1712862968,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"usage": {
"completion_tokens": 46,
"prompt_tokens": 10,
"total_tokens": 56
}
}

View File

@ -0,0 +1 @@
"<ClientResponse(http://localhost:9483/v1/completions) [200 OK]>\n<CIMultiDictProxy('Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'x-compute-type': '1-nvidia-a10g', 'x-compute-characters': '72', 'x-accel-buffering': 'no', 'Access-Control-Allow-Origin': '*', 'Vary': 'origin', 'Vary': 'access-control-request-method', 'Vary': 'access-control-request-headers', 'Transfer-Encoding': 'chunked', 'Date': 'Thu, 11 Apr 2024 19:19:32 GMT')>\n"

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " PR for flake8"
}
],
"created": 1712862926,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"usage": {
"completion_tokens": 5,
"prompt_tokens": 6,
"total_tokens": 11
}
}

View File

@ -39,6 +39,8 @@ def test_flash_llama_completion_single_prompt(
response = response.json()
assert len(response["choices"]) == 1
response == response_snapshot
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(
@ -46,7 +48,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
json={
"model": "tgi",
"prompt": ["Say", "this", "is", "a", "test"],
"max_tokens": 5,
"max_tokens": 10,
"seed": 0,
},
headers=flash_llama_completion.headers,
@ -59,32 +61,43 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
all_indexes.sort()
assert all_indexes == [0, 1, 2, 3, 4]
response == response_snapshot
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):
request = {
"model": "tgi",
"prompt": ["Say", "this", "is", "a", "test"],
"max_tokens": 5,
"prompt": [
"What color is the sky?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"stream": True,
}
headers = {
"Content-Type": "application/json",
}
url = f"{flash_llama_completion.base_url}/v1/completions"
async with ClientSession(headers=headers) as session:
async with session.post(url, json=request) as resp:
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in resp.content.iter_any():
# strip data: prefix and convert to json
data = json.loads(chunk.decode("utf-8")[5:])
assert "choices" in data
assert len(data["choices"]) == 1
assert data["choices"][0]["index"] in [*range(len(request["prompt"]))]
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# parse json
chunk = [json.loads(c) for c in chunk]
assert resp.status == 200
for c in chunk:
assert "choices" in c
assert 0 <= c["choices"][0]["index"] <= 4
assert response.status == 200
response == response_snapshot

View File

@ -15,7 +15,8 @@ use crate::{
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, ToolCall, ToolType};
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@ -23,8 +24,8 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::stream::FuturesUnordered;
use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
@ -35,7 +36,9 @@ use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
use tokio::sync::oneshot;
use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
@ -644,81 +647,125 @@ async fn completions(
let mut x_accel_buffering = "no".to_string();
if stream {
let response_streams = FuturesUnordered::new();
let mut response_streams = FuturesOrdered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() {
let model_id = info.model_id.clone();
let system_fingerprint =
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
// Create a future for each generate_stream_internal call.
let generate_future = async move {
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
event
.json_data(CompletionCompleteChunk {
id: "".to_string(),
object: "text_completion".to_string(),
created: current_time,
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
choices: vec![CompletionComplete {
finish_reason: "".to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
event
.json_data(CompletionCompleteChunk {
id: "".to_string(),
object: "text_completion".to_string(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
})
.map_or_else(
|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
},
|data| data,
choices: vec![CompletionComplete {
finish_reason: "".to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
})
.map_or_else(|_e| Event::default(), |data| data)
};
let (header_tx, header_rx) = oneshot::channel();
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
let (header_map, sse) = generate_stream_internal(
infer_clone.clone(),
compute_type_clone.clone(),
Json(generate_request),
on_message_callback,
)
.await;
// send and dont wait for response
let _ = header_tx.send(header_map);
// pin an emit messages to the sse_tx
let mut sse = Box::pin(sse);
while let Some(event) = sse.next().await {
sse_tx.send(event).expect("Failed to send event");
}
});
(index, header_rx, sse_rx)
};
let (headers, response_stream) = generate_stream_internal(
infer.clone(),
compute_type.clone(),
Json(generate_request),
on_message_callback,
)
.await;
response_streams.push_back(generate_future);
}
let mut all_rxs = vec![];
while let Some((index, header_rx, sse_rx)) = response_streams.next().await {
all_rxs.push(sse_rx);
// get the headers from the first response of each stream
let headers = header_rx.await.expect("Failed to get headers");
if index == 0 {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or("unknown".to_string());
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
x_accel_buffering = headers
.get("x-accel-buffering")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or("no".to_string());
.and_then(|v| v.to_str().ok())
.unwrap_or("no")
.to_string();
}
x_compute_characters += headers
.get("x-compute-characters")
.and_then(|v| v.to_str().ok()?.parse().ok())
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok())
.unwrap_or(0);
response_streams.push((headers, Box::pin(response_stream)));
}
let stream = async_stream::stream! {
for response_stream in response_streams {
let (_headers, mut inner_stream) = response_stream;
while let Some(event) = inner_stream.next().await {
yield event;
}
}
};
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
// now sink the sse streams into a single stream and remove the ones that are done
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
loop {
let mut i = 0;
while i < all_rxs.len() {
let rx = &mut all_rxs[i];
select! {
Some(event) = rx.recv() => {
yield event;
}
else => {
all_rxs.remove(i);
continue; // skip the increment to handle the next element at the same index
}
}
i += 1; // only increment when no element was removed
}
if all_rxs.is_empty() {
break;
}
}
};
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
return Ok((headers, sse).into_response());
}