Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 11:54:52 +00:00

commit a8ddf45c11 (parent 033d2174fd)

    cleanup
@@ -23,8 +23,14 @@ serde = "1.0.145"
 serde_json = "1.0.85"
 thiserror = "1.0.37"
 tokenizers = "0.13.0"
-tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] }
+tokio = { version = "1.21.1", features = [
+    "rt",
+    "rt-multi-thread",
+    "parking_lot",
+    "signal",
+    "sync",
+    "net",
+] }
 tracing = "0.1.36"
 tracing-subscriber = { version = "0.3.15", features = ["json"] }
 async-stream = "0.3.3"
 
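For context on the feature list being reflowed above: none of those tokio flags is decorative. Below is a minimal sketch of what each one buys the router (a hypothetical `main`, not the project's real entry point; note that `macros` is absent from the list, so the runtime would have to be built by hand rather than via `#[tokio::main]`):

    use tokio::runtime::Builder;

    fn main() {
        // "rt" + "rt-multi-thread": the multi-threaded scheduler;
        // "parking_lot" swaps tokio's internal locks for parking_lot's.
        let rt = Builder::new_multi_thread()
            .enable_io() // the I/O driver behind "net" (TCP listener) and "signal"
            .build()
            .unwrap();
        rt.block_on(async {
            // "sync": the channels wiring HTTP handlers to the batcher.
            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
            tx.send(1).unwrap();
            assert_eq!(rx.recv().await, Some(1));
            // "signal": graceful shutdown, e.g. tokio::signal::ctrl_c().await
        });
    }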
@@ -8,9 +8,9 @@ use nohash_hasher::IntMap;
 
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate};
+use text_generation_client::{Batch, ClientError, GeneratedText, Intermediate, ShardedClient};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify, mpsc};
+use tokio::sync::{mpsc, oneshot, Notify};
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -74,13 +74,6 @@ impl Batcher {
         // Notify the background task that we have a new entry in the database that needs
         // to be batched
         self.shared.batching_task.notify_one();
-
-        // // Await on the response from the background task
-        // // We can safely unwrap as the background task will never drop the sender
-        // response_rx
-        //     .await
-        //     .unwrap()
-        //     .map_err(|err| InferError::GenerationError(err.to_string()))
     }
 
     /// Add a new request to the database and return a future that will generate the text
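The block removed above was the leftover non-streaming path, in which the caller awaited the oneshot response directly instead of draining the intermediate channel. For reference, that pattern looks roughly like this (a sketch with hypothetical stand-in types; it relies on the invariant the deleted comment stated, that the background task never drops the sender):

    use tokio::sync::oneshot;

    async fn await_response(
        response_rx: oneshot::Receiver<Result<String, String>>,
    ) -> Result<String, String> {
        // unwrap() is only sound because the sender side is never dropped
        // before sending; otherwise this would panic on RecvError.
        response_rx.await.unwrap()
    }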
@@ -217,9 +210,9 @@ fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate>
 
 
         if let Some(tx) = &entry.intermediate_tx {
             // unwrap_or is valid here as we don't care if the receiver is gone.
             tx.send(Ok(Some(intermediate))).unwrap_or(());
         }
     });
 
     finished.into_iter().for_each(|output| {
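A note on the `unwrap_or(())` idiom visible in this hunk: `UnboundedSender::send` fails only when the receiving half has been dropped (a client that disconnected mid-stream), and that error is deliberately swallowed so the loop keeps serving the remaining entries. A standalone sketch of the behavior:

    use tokio::sync::mpsc;

    fn main() {
        let (tx, rx) = mpsc::unbounded_channel::<&str>();
        drop(rx); // simulate a client that went away mid-stream
        // send() now returns Err(SendError); ignoring it instead of
        // unwrapping keeps the sender's task alive.
        tx.send("token").unwrap_or(());
    }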
@@ -8,17 +8,17 @@ use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
 use std::sync::Arc;
-use text_generation_client::{ShardedClient, IntermediateEvent};
+use text_generation_client::{IntermediateEvent, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
+use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::instrument;
-use tokio::sync::{oneshot, mpsc};
 
 use axum::response::sse::{Event, KeepAlive, Sse};
-use std::convert::Infallible;
 use futures::stream::Stream;
+use std::convert::Infallible;
 
 // Server shared state
 #[derive(Clone)]
@@ -81,62 +81,71 @@ async fn generate_stream(
     let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
     let (response_tx, response_rx) = oneshot::channel();
 
-    let (input_length, validated_request) =
-        state.validation.validate(req.0).await.map_err(|err| {
+    let (input_length, validated_request) = state
+        .validation
+        .validate(req.0)
+        .await
+        .map_err(|err| {
             tracing::error!("{}", err.to_string());
             err
-        }).unwrap();
+        })
+        .unwrap();
 
     // Inference
-    state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx);
+    state.batcher.infer_stream(
+        input_length,
+        validated_request,
+        intermediate_tx,
+        response_tx,
+    );
 
     let stream = async_stream::stream! {
         while let Some(item) = intermediate_rx.recv().await {
             match item {
                 Ok(item) => {
                     match item {
                         Some(item) => {
                             let event_data = IntermediateEvent {
                                 token: item.token,
                                 token_id: item.token_id,
                                 logprob: item.logprob,
                             };
                             let stream_event = StreamEvent {
                                 is_end: false,
                                 event: Some(event_data),
                                 generated_text: None,
                             };
                             yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
                         }
                         None => {
                             break
-                        }
                         }
                     }
-                Err(err) => {
-                    yield Ok(Event::default().data(err.to_string()));
-                }
-            }
-        }
-        let response = response_rx.await.unwrap();
-        match response {
-            Ok(response) => {
-                let response = GeneratedText {
-                    generated_text: response.output_text,
-                    details: None,
-                };
-                let stream_event = StreamEvent {
-                    is_end: true,
-                    event: None,
-                    generated_text: Some(response),
-                };
-                yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
                 }
                 Err(err) => {
                     yield Ok(Event::default().data(err.to_string()));
                 }
             }
-        };
+        }
+        let response = response_rx.await.unwrap();
+        match response {
+            Ok(response) => {
+                let response = GeneratedText {
+                    generated_text: response.output_text,
+                    details: None,
+                };
+                let stream_event = StreamEvent {
+                    is_end: true,
+                    event: None,
+                    generated_text: Some(response),
+                };
+                yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+            }
+            Err(err) => {
+                yield Ok(Event::default().data(err.to_string()));
+            }
+        }
+    };
 
     Sse::new(stream).keep_alive(KeepAlive::default())
 }
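Taken together, the hunk above keeps the endpoint's behavior while fixing its formatting: intermediate tokens stream out as SSE frames with `is_end: false`, and once the channel closes the final text arrives in a single `is_end: true` frame. A minimal, self-contained sketch of the same pattern (a hypothetical handler, assuming the axum SSE API this file already imports and default serde field naming for the JSON payloads):

    use axum::response::sse::{Event, KeepAlive, Sse};
    use futures::stream::Stream;
    use std::convert::Infallible;

    async fn sse_demo() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
        let stream = async_stream::stream! {
            // Intermediate frames, one per token, e.g.
            // data: {"is_end":false,"event":{"token":"Hi","token_id":1,"logprob":-0.1},"generated_text":null}
            for token in ["Hello", " world"] {
                yield Ok(Event::default().data(token));
            }
            // Final frame, e.g.
            // data: {"is_end":true,"event":null,"generated_text":{"generated_text":"Hello world","details":null}}
            yield Ok(Event::default().data("end"));
        };
        // Keep-alive comments stop proxies from closing an idle stream.
        Sse::new(stream).keep_alive(KeepAlive::default())
    }

A client can watch these frames with `curl -N` against the streaming route, or with a browser `EventSource`.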