mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
refactor Stream impl for Generation to factorise code
This commit is contained in:
parent
b56c43ec30
commit
fd021e5461
@ -1,21 +1,22 @@
|
|||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::pin::{pin, Pin};
|
use std::pin::{pin, Pin};
|
||||||
use std::sync::Arc;
|
use std::str::FromStr;
|
||||||
|
use std::sync::{Arc, OnceLock};
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use cxx::UniquePtr;
|
use cxx::UniquePtr;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, warn};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tokio::time::{Instant, sleep};
|
use tokio::time::{Instant, sleep};
|
||||||
use tokio_stream::{Stream, StreamExt};
|
use tokio_stream::{Stream, StreamExt};
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{instrument, Level, span, Span};
|
use tracing::{instrument, Level, span};
|
||||||
|
|
||||||
use text_generation_router::{FinishReason, Token};
|
use text_generation_router::{FinishReason, Token};
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
@ -25,6 +26,9 @@ use text_generation_router::validation::ValidationError::UnsupportedModality;
|
|||||||
use crate::errors::TensorRtLlmBackendError;
|
use crate::errors::TensorRtLlmBackendError;
|
||||||
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
|
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
|
||||||
|
|
||||||
|
// Interval, in microseconds, at which the state of the generation stream is polled
|
||||||
|
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
||||||
|
|
||||||
type InferResult<T> = Result<T, InferError>;
|
type InferResult<T> = Result<T, InferError>;
|
||||||
|
|
||||||
pub(crate) struct Generation {
|
pub(crate) struct Generation {
|
||||||
@ -47,38 +51,34 @@ impl Stream for Generation {
|
|||||||
type Item = usize;
|
type Item = usize;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
if self.done.load(Ordering::Relaxed) {
|
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
||||||
Poll::Ready(None)
|
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
||||||
} else {
|
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
||||||
let pinned = pin!(self.executor.read());
|
});
|
||||||
match pinned.poll(ctx) {
|
|
||||||
|
if !self.done.load(Ordering::Relaxed) {
|
||||||
|
let backend = pin!(self.executor.read());
|
||||||
|
let status = match backend.poll(ctx) {
|
||||||
Poll::Ready(executor_r) => {
|
Poll::Ready(executor_r) => {
|
||||||
let ready = executor_r.num_responses_ready();
|
let ready = executor_r.num_responses_ready();
|
||||||
if ready == 0 {
|
if ready == 0 {
|
||||||
let waker = ctx.waker().clone();
|
|
||||||
tokio::spawn(async {
|
|
||||||
sleep(Duration::from_millis(10)).await;
|
|
||||||
waker.wake();
|
|
||||||
});
|
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
} else {
|
} else {
|
||||||
let waker = ctx.waker().clone();
|
|
||||||
tokio::spawn(async {
|
|
||||||
sleep(Duration::from_millis(100)).await;
|
|
||||||
waker.wake();
|
|
||||||
});
|
|
||||||
Poll::Ready(Some(ready))
|
Poll::Ready(Some(ready))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Poll::Pending => {
|
Poll::Pending => Poll::Pending,
|
||||||
|
};
|
||||||
|
|
||||||
let waker = ctx.waker().clone();
|
let waker = ctx.waker().clone();
|
||||||
tokio::spawn(async {
|
tokio::spawn(async {
|
||||||
sleep(Duration::from_millis(100)).await;
|
sleep(Duration::from_micros(*interval)).await;
|
||||||
waker.wake();
|
waker.wake();
|
||||||
});
|
});
|
||||||
Poll::Pending
|
|
||||||
}
|
status
|
||||||
}
|
} else {
|
||||||
|
Poll::Ready(None) // end of stream
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user