diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs index 296c7373..5fcc3d33 100644 --- a/backends/trtllm/src/backend.rs +++ b/backends/trtllm/src/backend.rs @@ -1,21 +1,22 @@ use std::future::Future; use std::path::Path; use std::pin::{pin, Pin}; -use std::sync::Arc; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; use std::sync::atomic::{AtomicBool, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; use async_trait::async_trait; use cxx::UniquePtr; -use log::{debug, info, warn}; +use log::{debug, warn}; use tokenizers::Tokenizer; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio::sync::RwLock; use tokio::time::{Instant, sleep}; use tokio_stream::{Stream, StreamExt}; use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{instrument, Level, span, Span}; +use tracing::{instrument, Level, span}; use text_generation_router::{FinishReason, Token}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; @@ -25,6 +26,9 @@ use text_generation_router::validation::ValidationError::UnsupportedModality; use crate::errors::TensorRtLlmBackendError; use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl}; +// Value used to poll the state of the generation stream +static POLLING_INTERVAL_US: OnceLock = OnceLock::new(); + type InferResult = Result; pub(crate) struct Generation { @@ -47,38 +51,34 @@ impl Stream for Generation { type Item = usize; fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - if self.done.load(Ordering::Relaxed) { - Poll::Ready(None) - } else { - let pinned = pin!(self.executor.read()); - match pinned.poll(ctx) { + let interval = POLLING_INTERVAL_US.get_or_init(|| { + u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100")) + .expect("Invalid value provided for envvar POLLING_INTERVAL_US") + }); + + if !self.done.load(Ordering::Relaxed) { + let backend = pin!(self.executor.read()); + let status = match backend.poll(ctx) { Poll::Ready(executor_r) => { let ready = executor_r.num_responses_ready(); if ready == 0 { - let waker = ctx.waker().clone(); - tokio::spawn(async { - sleep(Duration::from_millis(10)).await; - waker.wake(); - }); Poll::Pending } else { - let waker = ctx.waker().clone(); - tokio::spawn(async { - sleep(Duration::from_millis(100)).await; - waker.wake(); - }); Poll::Ready(Some(ready)) } } - Poll::Pending => { - let waker = ctx.waker().clone(); - tokio::spawn(async { - sleep(Duration::from_millis(100)).await; - waker.wake(); - }); - Poll::Pending - } - } + Poll::Pending => Poll::Pending, + }; + + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_micros(*interval)).await; + waker.wake(); + }); + + status + } else { + Poll::Ready(None) // end of stream } }