From a8ddf45c11ddf41ae64541d3542b305d3014599f Mon Sep 17 00:00:00 2001
From: Yannic Kilcher
Date: Thu, 26 Jan 2023 14:57:49 +0100
Subject: [PATCH] cleanup

---
 router/Cargo.toml     |  10 +++-
 router/src/batcher.rs |  17 ++-----
 router/src/server.rs  | 105 +++++++++++++++++++++++-------------------
 3 files changed, 70 insertions(+), 62 deletions(-)

diff --git a/router/Cargo.toml b/router/Cargo.toml
index 94b3be1c..cf9091b6 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -23,8 +23,14 @@ serde = "1.0.145"
 serde_json = "1.0.85"
 thiserror = "1.0.37"
 tokenizers = "0.13.0"
-tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] }
+tokio = { version = "1.21.1", features = [
+    "rt",
+    "rt-multi-thread",
+    "parking_lot",
+    "signal",
+    "sync",
+    "net",
+] }
 tracing = "0.1.36"
 tracing-subscriber = { version = "0.3.15", features = ["json"] }
 async-stream = "0.3.3"
-
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index da74a6db..5c0681ee 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -8,9 +8,9 @@ use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate};
+use text_generation_client::{Batch, ClientError, GeneratedText, Intermediate, ShardedClient};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify, mpsc};
+use tokio::sync::{mpsc, oneshot, Notify};
 use tokio::time::Instant;
 use tracing::instrument;
@@ -74,13 +74,6 @@ impl Batcher {
         // Notify the background task that we have a new entry in the database that needs
         // to be batched
         self.shared.batching_task.notify_one();
-
-        // // Await on the response from the background task
-        // // We can safely unwrap as the background task will never drop the sender
-        // response_rx
-        //     .await
-        //     .unwrap()
-        //     .map_err(|err| InferError::GenerationError(err.to_string()))
     }

     /// Add a new request to the database and return a future that will generate the text
@@ -217,9 +210,9 @@ fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate
         if let Some(tx) = &entry.intermediate_tx {
-                // unwrap_or is valid here as we don't care if the receiver is gone.
-                tx.send(Ok(Some(intermediate))).unwrap_or(());
-            }
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            tx.send(Ok(Some(intermediate))).unwrap_or(());
+        }
     });

     finished.into_iter().for_each(|output| {
diff --git a/router/src/server.rs b/router/src/server.rs
index 14418a28..e60d94c8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -8,17 +8,17 @@ use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
 use std::sync::Arc;
-use text_generation_client::{ShardedClient, IntermediateEvent};
+use text_generation_client::{IntermediateEvent, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
+use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::instrument;
-use tokio::sync::{oneshot, mpsc};

 use axum::response::sse::{Event, KeepAlive, Sse};
-use std::convert::Infallible;
 use futures::stream::Stream;
+use std::convert::Infallible;

 // Server shared state
 #[derive(Clone)]
@@ -81,62 +81,71 @@ async fn generate_stream(
     let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
     let (response_tx, response_rx) = oneshot::channel();

-    let (input_length, validated_request) =
-        state.validation.validate(req.0).await.map_err(|err| {
+    let (input_length, validated_request) = state
+        .validation
+        .validate(req.0)
+        .await
+        .map_err(|err| {
             tracing::error!("{}", err.to_string());
             err
-        }).unwrap();
+        })
+        .unwrap();

     // Inference
-    state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx);
-
-    let stream = async_stream::stream! {
-        while let Some(item) = intermediate_rx.recv().await {
-            match item {
-                Ok(item) => {
-                    match item {
-                        Some(item) => {
-                            let event_data = IntermediateEvent {
-                                token: item.token,
-                                token_id: item.token_id,
-                                logprob: item.logprob,
-                            };
-                            let stream_event = StreamEvent {
-                                is_end: false,
-                                event: Some(event_data),
-                                generated_text: None,
-                            };
-                            yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
-                        }
-                        None => {
-                            break
-                        } } }
-                Err(err) => {
-                    yield Ok(Event::default().data(err.to_string()));
-                }
-            }
-        }
-        let response = response_rx.await.unwrap();
-        match response {
-            Ok(response) => {
-                let response = GeneratedText {
-                    generated_text: response.output_text,
-                    details: None,
-                };
-                let stream_event = StreamEvent {
-                    is_end: true,
-                    event: None,
-                    generated_text: Some(response),
-                };
-                yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap())); }
-            Err(err) => {
-                yield Ok(Event::default().data(err.to_string()));
-            }
-        }
-    };
+    state.batcher.infer_stream(
+        input_length,
+        validated_request,
+        intermediate_tx,
+        response_tx,
+    );
+
+    let stream = async_stream::stream! {
+        while let Some(item) = intermediate_rx.recv().await {
+            match item {
+                Ok(item) => {
+                    match item {
+                        Some(item) => {
+                            let event_data = IntermediateEvent {
+                                token: item.token,
+                                token_id: item.token_id,
+                                logprob: item.logprob,
+                            };
+                            let stream_event = StreamEvent {
+                                is_end: false,
+                                event: Some(event_data),
+                                generated_text: None,
+                            };
+                            yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+                        }
+                        None => {
+                            break
+                        }
+                    }
+                }
+                Err(err) => {
+                    yield Ok(Event::default().data(err.to_string()));
+                }
+            }
+        }
+
+        let response = response_rx.await.unwrap();
+        match response {
+            Ok(response) => {
+                let response = GeneratedText {
+                    generated_text: response.output_text,
+                    details: None,
+                };
+                let stream_event = StreamEvent {
+                    is_end: true,
+                    event: None,
+                    generated_text: Some(response),
+                };
+                yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+            }
+            Err(err) => {
+                yield Ok(Event::default().data(err.to_string()));
+            }
+        }
+    };

     Sse::new(stream).keep_alive(KeepAlive::default())
 }
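
Note for reviewers: generate_stream now emits two kinds of SSE `data:` payloads, a per-token StreamEvent with `is_end: false`, and a terminal one carrying the full GeneratedText. The sketch below shows how a client might decode them. The struct shapes mirror the fields visible in this diff, but the concrete field types (String, u32, f32) and the serde setup are assumptions, since the actual definitions are not part of this patch.

// Sketch: decoding the SSE `data:` payloads emitted by generate_stream.
// Field types are assumptions inferred from the diff above; the real
// definitions live in the router crate and are not part of this patch.
// Assumed deps: serde = { version = "1", features = ["derive"] }, serde_json = "1".
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct IntermediateEvent {
    token: String, // assumed String
    token_id: u32, // assumed u32
    logprob: f32,  // assumed f32
}

#[derive(Debug, Deserialize)]
struct GeneratedText {
    generated_text: String,
    details: Option<serde_json::Value>, // always None in this patch
}

#[derive(Debug, Deserialize)]
struct StreamEvent {
    is_end: bool,
    event: Option<IntermediateEvent>,
    generated_text: Option<GeneratedText>,
}

fn main() {
    // One per-token event followed by the terminal event, in the shape that
    // serde_json::to_string(&stream_event) would produce on the server side.
    let frames = [
        r#"{"is_end":false,"event":{"token":" world","token_id":1917,"logprob":-0.5},"generated_text":null}"#,
        r#"{"is_end":true,"event":null,"generated_text":{"generated_text":"Hello world","details":null}}"#,
    ];
    for frame in frames {
        let ev: StreamEvent = serde_json::from_str(frame).expect("valid StreamEvent JSON");
        if ev.is_end {
            println!("final: {:?}", ev.generated_text.map(|g| g.generated_text));
        } else if let Some(tok) = ev.event {
            println!("token {:?} (id {}, logprob {})", tok.token, tok.token_id, tok.logprob);
        }
    }
}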