mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
refactor Stream impl for Generation to factorise code
This commit is contained in:
parent
b56c43ec30
commit
fd021e5461
@ -1,21 +1,22 @@
|
||||
use std::future::Future;
|
||||
use std::path::Path;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::sync::Arc;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cxx::UniquePtr;
|
||||
use log::{debug, info, warn};
|
||||
use log::{debug, warn};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::{Instant, sleep};
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{instrument, Level, span, Span};
|
||||
use tracing::{instrument, Level, span};
|
||||
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
@ -25,6 +26,9 @@ use text_generation_router::validation::ValidationError::UnsupportedModality;
|
||||
use crate::errors::TensorRtLlmBackendError;
|
||||
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
|
||||
|
||||
// Interval, in microseconds, between two successive polls of the generation stream state
|
||||
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
||||
|
||||
type InferResult<T> = Result<T, InferError>;
|
||||
|
||||
pub(crate) struct Generation {
|
||||
@ -47,38 +51,34 @@ impl Stream for Generation {
|
||||
type Item = usize;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
if self.done.load(Ordering::Relaxed) {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
let pinned = pin!(self.executor.read());
|
||||
match pinned.poll(ctx) {
|
||||
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
||||
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
||||
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
||||
});
|
||||
|
||||
if !self.done.load(Ordering::Relaxed) {
|
||||
let backend = pin!(self.executor.read());
|
||||
let status = match backend.poll(ctx) {
|
||||
Poll::Ready(executor_r) => {
|
||||
let ready = executor_r.num_responses_ready();
|
||||
if ready == 0 {
|
||||
let waker = ctx.waker().clone();
|
||||
tokio::spawn(async {
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
waker.wake();
|
||||
});
|
||||
Poll::Pending
|
||||
} else {
|
||||
let waker = ctx.waker().clone();
|
||||
tokio::spawn(async {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
waker.wake();
|
||||
});
|
||||
Poll::Ready(Some(ready))
|
||||
}
|
||||
}
|
||||
Poll::Pending => {
|
||||
let waker = ctx.waker().clone();
|
||||
tokio::spawn(async {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
waker.wake();
|
||||
});
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
};
|
||||
|
||||
let waker = ctx.waker().clone();
|
||||
tokio::spawn(async {
|
||||
sleep(Duration::from_micros(*interval)).await;
|
||||
waker.wake();
|
||||
});
|
||||
|
||||
status
|
||||
} else {
|
||||
Poll::Ready(None) // end of stream
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user