This commit is contained in:
Yannic Kilcher 2023-01-26 14:57:49 +01:00 committed by OlivierDehaene
parent 033d2174fd
commit a8ddf45c11
3 changed files with 70 additions and 62 deletions

View File

@ -23,8 +23,14 @@ serde = "1.0.145"
serde_json = "1.0.85" serde_json = "1.0.85"
thiserror = "1.0.37" thiserror = "1.0.37"
tokenizers = "0.13.0" tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] } tokio = { version = "1.21.1", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
"net",
] }
tracing = "0.1.36" tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] } tracing-subscriber = { version = "0.3.15", features = ["json"] }
async-stream = "0.3.3" async-stream = "0.3.3"

View File

@ -8,9 +8,9 @@ use nohash_hasher::IntMap;
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate}; use text_generation_client::{Batch, ClientError, GeneratedText, Intermediate, ShardedClient};
use thiserror::Error; use thiserror::Error;
use tokio::sync::{oneshot, Notify, mpsc}; use tokio::sync::{mpsc, oneshot, Notify};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::instrument; use tracing::instrument;
@ -74,13 +74,6 @@ impl Batcher {
// Notify the background task that we have a new entry in the database that needs // Notify the background task that we have a new entry in the database that needs
// to be batched // to be batched
self.shared.batching_task.notify_one(); self.shared.batching_task.notify_one();
// // Await on the response from the background task
// // We can safely unwrap as the background task will never drop the sender
// response_rx
// .await
// .unwrap()
// .map_err(|err| InferError::GenerationError(err.to_string()))
} }
/// Add a new request to the database and return a future that will generate the text /// Add a new request to the database and return a future that will generate the text
@ -217,9 +210,9 @@ fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate>
if let Some(tx) = &entry.intermediate_tx { if let Some(tx) = &entry.intermediate_tx {
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.
tx.send(Ok(Some(intermediate))).unwrap_or(()); tx.send(Ok(Some(intermediate))).unwrap_or(());
} }
}); });
finished.into_iter().for_each(|output| { finished.into_iter().for_each(|output| {

View File

@ -8,17 +8,17 @@ use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ShardedClient, IntermediateEvent}; use text_generation_client::{IntermediateEvent, ShardedClient};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::signal; use tokio::signal;
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::instrument; use tracing::instrument;
use tokio::sync::{oneshot, mpsc};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
use std::convert::Infallible;
use futures::stream::Stream; use futures::stream::Stream;
use std::convert::Infallible;
// Server shared state // Server shared state
#[derive(Clone)] #[derive(Clone)]
@ -81,62 +81,71 @@ async fn generate_stream(
let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel(); let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
let (response_tx, response_rx) = oneshot::channel(); let (response_tx, response_rx) = oneshot::channel();
let (input_length, validated_request) = let (input_length, validated_request) = state
state.validation.validate(req.0).await.map_err(|err| { .validation
.validate(req.0)
.await
.map_err(|err| {
tracing::error!("{}", err.to_string()); tracing::error!("{}", err.to_string());
err err
}).unwrap(); })
.unwrap();
// Inference // Inference
state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx); state.batcher.infer_stream(
input_length,
validated_request,
intermediate_tx,
response_tx,
);
let stream = async_stream::stream! { let stream = async_stream::stream! {
while let Some(item) = intermediate_rx.recv().await { while let Some(item) = intermediate_rx.recv().await {
match item { match item {
Ok(item) => { Ok(item) => {
match item { match item {
Some(item) => { Some(item) => {
let event_data = IntermediateEvent { let event_data = IntermediateEvent {
token: item.token, token: item.token,
token_id: item.token_id, token_id: item.token_id,
logprob: item.logprob, logprob: item.logprob,
}; };
let stream_event = StreamEvent { let stream_event = StreamEvent {
is_end: false, is_end: false,
event: Some(event_data), event: Some(event_data),
generated_text: None, generated_text: None,
}; };
yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap())); yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
} }
None => { None => {
break break
}
} }
} }
Err(err) => {
yield Ok(Event::default().data(err.to_string()));
}
}
}
let response = response_rx.await.unwrap();
match response {
Ok(response) => {
let response = GeneratedText {
generated_text: response.output_text,
details: None,
};
let stream_event = StreamEvent {
is_end: true,
event: None,
generated_text: Some(response),
};
yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
} }
Err(err) => { Err(err) => {
yield Ok(Event::default().data(err.to_string())); yield Ok(Event::default().data(err.to_string()));
} }
} }
}; }
let response = response_rx.await.unwrap();
match response {
Ok(response) => {
let response = GeneratedText {
generated_text: response.output_text,
details: None,
};
let stream_event = StreamEvent {
is_end: true,
event: None,
generated_text: Some(response),
};
yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
}
Err(err) => {
yield Ok(Event::default().data(err.to_string()));
}
}
};
Sse::new(stream).keep_alive(KeepAlive::default()) Sse::new(stream).keep_alive(KeepAlive::default())
} }