diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json
new file mode 100644
index 00000000..4aec05b8
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json
@@ -0,0 +1,44 @@
+{
+  "choices": [
+    {
+      "finish_reason": "eos_token",
+      "index": 0,
+      "logprobs": null,
+      "text": " PR for more information?"
+    },
+    {
+      "finish_reason": "length",
+      "index": 1,
+      "logprobs": null,
+      "text": "le Business Incubator is providing a workspace"
+    },
+    {
+      "finish_reason": "length",
+      "index": 2,
+      "logprobs": null,
+      "text": "hd20220811-"
+    },
+    {
+      "finish_reason": "length",
+      "index": 3,
+      "logprobs": null,
+      "text": " severely flawed and often has a substandard"
+    },
+    {
+      "finish_reason": "length",
+      "index": 4,
+      "logprobs": null,
+      "text": "](https://i.imgur.com/as"
+    }
+  ],
+  "created": 1712862968,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.5-native",
+  "usage": {
+    "completion_tokens": 46,
+    "prompt_tokens": 10,
+    "total_tokens": 56
+  }
+}
diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json
new file mode 100644
index 00000000..bdea1b77
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json
@@ -0,0 +1 @@
+"\n\n"
diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json
new file mode 100644
index 00000000..51193b0c
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "text": " PR for flake8"
+    }
+  ],
+  "created": 1712862926,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.5-native",
+  "usage": {
+    "completion_tokens": 5,
+    "prompt_tokens": 6,
+    "total_tokens": 11
+  }
+}
diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py
index d410ad2f..7d1f2c1d 100644
--- a/integration-tests/models/test_completion_prompts.py
+++ b/integration-tests/models/test_completion_prompts.py
@@ -39,6 +39,8 @@ def test_flash_llama_completion_single_prompt(
     response = response.json()
     assert len(response["choices"]) == 1
 
+    response == response_snapshot
+
 
 def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
     response = requests.post(
@@ -46,7 +48,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
         json={
             "model": "tgi",
             "prompt": ["Say", "this", "is", "a", "test"],
-            "max_tokens": 5,
+            "max_tokens": 10,
             "seed": 0,
         },
         headers=flash_llama_completion.headers,
@@ -59,32 +61,43 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
     all_indexes.sort()
    assert all_indexes == [0, 1, 2, 3, 4]
 
+    response == response_snapshot
+
 
 async def test_flash_llama_completion_many_prompts_stream(
     flash_llama_completion, response_snapshot
 ):
     request = {
         "model": "tgi",
-        "prompt": ["Say", "this", "is", "a", "test"],
-        "max_tokens": 5,
+        "prompt": [
+            "What color is the sky?",
+            "Is water wet?",
+            "What is the capital of France?",
+            "def mai",
+        ],
+        "max_tokens": 10,
         "seed": 0,
         "stream": True,
     }
 
-    headers = {
-        "Content-Type": "application/json",
-    }
-
     url = f"{flash_llama_completion.base_url}/v1/completions"
-    async with ClientSession(headers=headers) as session:
-        async with session.post(url, json=request) as resp:
+    async with ClientSession(headers=flash_llama_completion.headers) as session:
+        async with session.post(url, json=request) as response:
             # iterate over the stream
-            async for chunk in resp.content.iter_any():
-                # strip data: prefix and convert to json
-                data = json.loads(chunk.decode("utf-8")[5:])
-                assert "choices" in data
-                assert len(data["choices"]) == 1
-                assert data["choices"][0]["index"] in [*range(len(request["prompt"]))]
+            async for chunk in response.content.iter_any():
+                # split the payload into individual SSE events
+                chunk = chunk.decode().split("\n\n")
+                # remove "data:" if present
+                chunk = [c.replace("data:", "") for c in chunk]
+                # remove empty strings
+                chunk = [c for c in chunk if c]
+                # parse json
+                chunk = [json.loads(c) for c in chunk]
 
-            assert resp.status == 200
+                for c in chunk:
+                    assert "choices" in c
+                    assert 0 <= c["choices"][0]["index"] <= 4
+
+    assert response.status == 200
+    response == response_snapshot
 
diff --git a/router/src/server.rs b/router/src/server.rs
index 3ea6813e..d804da2f 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -15,7 +15,8 @@ use crate::{
     ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
     CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
 };
-use crate::{FunctionDefinition, ToolCall, ToolType};
+use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
+use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -23,8 +24,8 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
-use futures::stream::FuturesUnordered;
 use futures::stream::StreamExt;
+use futures::stream::{FuturesOrdered, FuturesUnordered};
 use futures::Stream;
 use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
@@ -35,7 +36,9 @@ use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
+use tokio::select;
 use tokio::signal;
+use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -644,81 +647,125 @@ async fn completions(
     let mut x_accel_buffering = "no".to_string();
 
     if stream {
-        let response_streams = FuturesUnordered::new();
+        let mut response_streams = FuturesOrdered::new();
         for (index, generate_request) in generate_requests.into_iter().enumerate() {
             let model_id = info.model_id.clone();
             let system_fingerprint =
                 format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
-            let on_message_callback = move |stream_token: StreamResponse| {
-                let event = Event::default();
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
 
-                let current_time = std::time::SystemTime::now()
-                    .duration_since(std::time::UNIX_EPOCH)
-                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                    .as_secs();
+            // Create a future for each generate_stream_internal call.
+            let generate_future = async move {
+                let on_message_callback = move |stream_token: StreamResponse| {
+                    let event = Event::default();
 
-                event
-                    .json_data(CompletionCompleteChunk {
-                        id: "".to_string(),
-                        object: "text_completion".to_string(),
-                        created: current_time,
+                    let current_time = std::time::SystemTime::now()
+                        .duration_since(std::time::UNIX_EPOCH)
+                        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                        .as_secs();
 
-                        choices: vec![CompletionComplete {
-                            finish_reason: "".to_string(),
-                            index: index as u32,
-                            logprobs: None,
-                            text: stream_token.token.text,
-                        }],
+                    event
+                        .json_data(CompletionCompleteChunk {
+                            id: "".to_string(),
+                            object: "text_completion".to_string(),
+                            created: current_time,
 
-                        model: model_id.clone(),
-                        system_fingerprint: system_fingerprint.clone(),
-                    })
-                    .map_or_else(
-                        |e| {
-                            println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-                            Event::default()
-                        },
-                        |data| data,
+                            choices: vec![CompletionComplete {
+                                finish_reason: "".to_string(),
+                                index: index as u32,
+                                logprobs: None,
+                                text: stream_token.token.text,
+                            }],
+
+                            model: model_id.clone(),
+                            system_fingerprint: system_fingerprint.clone(),
+                        })
+                        .map_or_else(|_e| Event::default(), |data| data)
+                };
+
+                let (header_tx, header_rx) = oneshot::channel();
+                let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
+
+                tokio::spawn(async move {
+                    let (header_map, sse) = generate_stream_internal(
+                        infer_clone.clone(),
+                        compute_type_clone.clone(),
+                        Json(generate_request),
+                        on_message_callback,
                     )
+                    .await;
+
+                    // send and don't wait for response
+                    let _ = header_tx.send(header_map);
+
+                    // pin and emit messages to the sse_tx
+                    let mut sse = Box::pin(sse);
+                    while let Some(event) = sse.next().await {
+                        sse_tx.send(event).expect("Failed to send event");
+                    }
+                });
+
+                (index, header_rx, sse_rx)
             };
 
-            let (headers, response_stream) = generate_stream_internal(
-                infer.clone(),
-                compute_type.clone(),
-                Json(generate_request),
-                on_message_callback,
-            )
-            .await;
+            response_streams.push_back(generate_future);
+        }
+
+        let mut all_rxs = vec![];
+
+        while let Some((index, header_rx, sse_rx)) = response_streams.next().await {
+            all_rxs.push(sse_rx);
+
+            // get the headers from the first response of each stream
+            let headers = header_rx.await.expect("Failed to get headers");
             if index == 0 {
                 x_compute_type = headers
                     .get("x-compute-type")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or("unknown".to_string());
+                    .and_then(|v| v.to_str().ok())
+                    .unwrap_or("unknown")
+                    .to_string();
                 x_accel_buffering = headers
                     .get("x-accel-buffering")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or("no".to_string());
+                    .and_then(|v| v.to_str().ok())
+                    .unwrap_or("no")
+                    .to_string();
             }
             x_compute_characters += headers
                 .get("x-compute-characters")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
+                .and_then(|v| v.to_str().ok())
+                .and_then(|v| v.parse().ok())
                 .unwrap_or(0);
-            response_streams.push((headers, Box::pin(response_stream)));
         }
 
-        let stream = async_stream::stream! {
-            for response_stream in response_streams {
-                let (_headers, mut inner_stream) = response_stream;
-                while let Some(event) = inner_stream.next().await {
-                    yield event;
-                }
-            }
-        };
-
         let mut headers = HeaderMap::new();
         headers.insert("x-compute-type", x_compute_type.parse().unwrap());
         headers.insert("x-compute-characters", x_compute_characters.into());
         headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
 
+        // now sink the sse streams into a single stream and remove the ones that are done
+        let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
+            loop {
+                let mut i = 0;
+                while i < all_rxs.len() {
+                    let rx = &mut all_rxs[i];
+                    select! {
+                        Some(event) = rx.recv() => {
+                            yield event;
+                        }
+                        else => {
+                            all_rxs.remove(i);
+                            continue; // skip the increment to handle the next element at the same index
+                        }
+                    }
+                    i += 1; // only increment when no element was removed
+                }
+
+                if all_rxs.is_empty() {
+                    break;
+                }
+            }
+        };
+
         let sse = Sse::new(stream).keep_alive(KeepAlive::default());
         return Ok((headers, sse).into_response());
     }
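
Note on the fan-in used above: the handler spawns one generate_stream_internal task per prompt, each writing into its own unbounded mpsc channel, and the final async_stream! loop drains those receivers with select!, removing a receiver once its sender side closes. The following standalone sketch (not part of the patch) shows the same pattern in isolation; the String payloads and loop bounds are illustrative only (the real code forwards axum SSE Event values), and it assumes the tokio, futures and async-stream crates as dependencies.

    use futures::StreamExt;
    use tokio::select;
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        // One unbounded channel per "prompt", mirroring sse_tx / sse_rx in the patch.
        let mut all_rxs = Vec::new();
        for index in 0..3u32 {
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                for token in 0..5u32 {
                    // Illustrative payload; the completions handler sends SSE events here.
                    let _ = tx.send(format!("prompt {index}, token {token}"));
                }
                // Dropping tx closes the channel; the fan-in loop then removes this receiver.
            });
            all_rxs.push(rx);
        }

        // Same shape as the async_stream! block in the patch: poll each receiver in turn,
        // yield whatever arrives, and drop receivers whose senders have gone away.
        let merged = async_stream::stream! {
            loop {
                let mut i = 0;
                while i < all_rxs.len() {
                    let rx = &mut all_rxs[i];
                    select! {
                        Some(msg) = rx.recv() => {
                            yield msg;
                        }
                        else => {
                            all_rxs.remove(i);
                            continue; // the next receiver now lives at index i
                        }
                    }
                    i += 1;
                }

                if all_rxs.is_empty() {
                    break;
                }
            }
        };

        let mut merged = Box::pin(merged);
        while let Some(msg) = merged.next().await {
            println!("{msg}");
        }
    }

Because the receivers are polled round-robin rather than merged with FuturesUnordered, events from one prompt cannot starve the others, and the outer FuturesOrdered in the patch keeps header handling in prompt order.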