refactor Stream impl for Generation to factorise code

Morgan Funtowicz 2024-07-18 14:21:43 +00:00
parent b56c43ec30
commit fd021e5461


@@ -1,21 +1,22 @@
 use std::future::Future;
 use std::path::Path;
 use std::pin::{pin, Pin};
-use std::sync::Arc;
+use std::str::FromStr;
+use std::sync::{Arc, OnceLock};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::task::{Context, Poll};
 use std::time::Duration;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
-use log::{debug, info, warn};
+use log::{debug, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::RwLock;
 use tokio::time::{Instant, sleep};
 use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{instrument, Level, span, Span};
+use tracing::{instrument, Level, span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -25,6 +26,9 @@ use text_generation_router::validation::ValidationError::UnsupportedModality;
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
+// Value used to poll the state of the generation stream
+static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
+
 type InferResult<T> = Result<T, InferError>;
 
 pub(crate) struct Generation {
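
The new POLLING_INTERVAL_US static is initialised lazily from a compile-time environment variable. As a reminder of how that pattern behaves, here is a standalone sketch (assuming the same variable name and the 100µs default, but not taken from this commit): option_env! is expanded when the crate is built, so the value cannot be changed at runtime, and OnceLock::get_or_init caches the parsed number on first use.

use std::str::FromStr;
use std::sync::OnceLock;

// Standalone sketch: a lazily-initialised, compile-time-configurable
// polling interval with a default of 100µs.
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();

fn polling_interval_us() -> u64 {
    *POLLING_INTERVAL_US.get_or_init(|| {
        // option_env! is evaluated when the crate is built, not at runtime.
        u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
            .expect("Invalid value provided for envvar POLLING_INTERVAL_US")
    })
}

fn main() {
    println!("polling every {}µs", polling_interval_us());
}
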
@@ -47,38 +51,34 @@ impl Stream for Generation {
     type Item = usize;
 
     fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        if self.done.load(Ordering::Relaxed) {
-            Poll::Ready(None)
-        } else {
-            let pinned = pin!(self.executor.read());
-            match pinned.poll(ctx) {
+        let interval = POLLING_INTERVAL_US.get_or_init(|| {
+            u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
+                .expect("Invalid value provided for envvar POLLING_INTERVAL_US")
+        });
+
+        if !self.done.load(Ordering::Relaxed) {
+            let backend = pin!(self.executor.read());
+            let status = match backend.poll(ctx) {
                 Poll::Ready(executor_r) => {
                     let ready = executor_r.num_responses_ready();
                     if ready == 0 {
-                        let waker = ctx.waker().clone();
-                        tokio::spawn(async {
-                            sleep(Duration::from_millis(10)).await;
-                            waker.wake();
-                        });
                         Poll::Pending
                     } else {
-                        let waker = ctx.waker().clone();
-                        tokio::spawn(async {
-                            sleep(Duration::from_millis(100)).await;
-                            waker.wake();
-                        });
                         Poll::Ready(Some(ready))
                     }
                 }
-                Poll::Pending => {
-                    let waker = ctx.waker().clone();
-                    tokio::spawn(async {
-                        sleep(Duration::from_millis(100)).await;
-                        waker.wake();
-                    });
-                    Poll::Pending
-                }
-            }
+                Poll::Pending => Poll::Pending,
+            };
+
+            let waker = ctx.waker().clone();
+            tokio::spawn(async {
+                sleep(Duration::from_micros(*interval)).await;
+                waker.wake();
+            });
+
+            status
+        } else {
+            Poll::Ready(None) // end of stream
         }
     }
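
The refactor factors the timer out of the individual match arms: the executor is polled once, the outcome is captured in status, a single delayed wake-up is scheduled using the shared interval, and the captured outcome is returned. A minimal, self-contained sketch of the same shape follows; the Ticker type, its fields, and the fixed 100µs delay are hypothetical and only illustrate the pattern, not the TensorRT-LLM backend itself.

use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;

use tokio::time::sleep;
use tokio_stream::{Stream, StreamExt};

// Hypothetical countdown stream structured like the refactored poll_next:
// decide the poll result in one place, then re-arm the waker in one place.
struct Ticker {
    remaining: usize,
    done: Arc<AtomicBool>,
}

impl Stream for Ticker {
    type Item = usize;

    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if !self.done.load(Ordering::Relaxed) {
            // Single place where the item (or Pending) is computed.
            let status = if self.remaining == 0 {
                self.done.store(true, Ordering::Relaxed);
                Poll::Ready(None)
            } else {
                self.remaining -= 1;
                Poll::Ready(Some(self.remaining))
            };

            // Single place where the next wake-up is scheduled.
            let waker = ctx.waker().clone();
            tokio::spawn(async move {
                sleep(Duration::from_micros(100)).await;
                waker.wake();
            });

            status
        } else {
            Poll::Ready(None) // end of stream
        }
    }
}

#[tokio::main]
async fn main() {
    let mut ticker = Ticker { remaining: 3, done: Arc::new(AtomicBool::new(false)) };
    while let Some(n) = ticker.next().await {
        println!("{n} responses still pending");
    }
}

Scheduling the wake-up even when an item is returned is redundant but harmless; the gain is that the sleep-and-wake logic now exists in exactly one place instead of in every branch.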