impl the Rust backend, which currently cannot move the actual computation to a background thread

Morgan Funtowicz 2024-07-12 19:26:32 +00:00
parent 518d9a9e0b
commit b291be64a0


@@ -8,13 +8,16 @@ use tokio::sync::mpsc;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_trtllm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
-struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+type InferResult<T> = Result<T, InferError>;
+
+pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
 
 pub struct TrtLLmBackend {
     tokenizer: Tokenizer,
@@ -30,42 +33,32 @@ impl TrtLLmBackend {
         engine_folder: P,
     ) -> Result<Self, TensorRtLlmBackendError> {
         let engine_folder = engine_folder.as_ref();
-        let inner = create_trtllm_backend(engine_folder.to_str().unwrap(), "");
+        let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
 
         Ok(Self {
             tokenizer,
             inner: RefCell::new(inner),
         })
     }
-}
 
-#[async_trait]
-impl Backend for TrtLLmBackend {
-    fn schedule(
+    fn infer_text(
         &self,
-        request: ValidGenerateRequest,
-    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let (sender, receiver) = mpsc::unbounded_channel();
-        let ctx = Box::new(GenerationContext(sender));
+        ctx: GenerationContext,
+        text: &str,
+        params: ValidParameters,
+    ) -> InferResult<()> {
         // Keep track of processing time
         let start = Instant::now();
 
-        // Unpack parameters
-        let params = request.parameters;
-        // Currently we handle single chunk of text
-        if request.inputs.len() == 1 {
-            match request
-                .inputs
-                .first()
-                .expect("Failed to access the first chunk")
-            {
-                Chunk::Text(text) => {
-                    let encoding = self
-                        .tokenizer
-                        .encode(&**text, true)
-                        .map_err(|e| InferError::ToolError(e.to_string()))?;
-                    let _start = Instant::now();
-                    let _request_id = self
-                        .inner
-                        .borrow_mut()
-                        .as_mut()
+        // Encode the input
+        let ctx = Box::new(ctx);
+        let encoding = self
+            .tokenizer
+            .encode(text, true)
+            .map_err(|e| InferError::ToolError(e.to_string()))?;
+
+        // Submit the request to the backend and retrieve the handle to query its status
+        let request_id = self
+            .inner
+            .borrow_mut()
+            .as_mut()
@@ -79,52 +72,83 @@ impl Backend for TrtLLmBackend {
                 params.seed,
             );
 
-                    // spawn_blocking(|| {
-                    //     // Stream generated tokens
-                    //     let num_generated_tokens = self
-                    //         .inner
-                    //         .borrow_mut()
-                    //         .as_mut()
-                    //         .expect("Failed to retrieve pointer to TRTLLM backend")
-                    //         .stream(request_id, ctx, |token, step, is_final| {
-                    //             // self.tokenizer.decode(&*[token], true).unwrap();
-                    //             let token = Token {
-                    //                 id: token,
-                    //                 text: String::from(""),
-                    //                 logprob: 1.0f32,
-                    //                 special: false,
-                    //             };
-                    //
-                    //             sender
-                    //                 .send(Ok(InferStreamResponse::Intermediate {
-                    //                     token,
-                    //                     top_tokens: vec![],
-                    //                 }))
-                    //                 .unwrap()
-                    //         });
-                    //
-                    //     // Notify the end
-                    //     Ok(InferStreamResponse::End {
-                    //         token: Token {
-                    //             id: 0,
-                    //             text: String::from(""),
-                    //             logprob: 1.0f32,
-                    //             special: false,
-                    //         },
-                    //         top_tokens: vec![],
-                    //         generated_text: GeneratedText {
-                    //             text: String::from(""),
-                    //             generated_tokens: num_generated_tokens,
-                    //             finish_reason: FinishReason::EndOfSequenceToken,
-                    //             seed: Some(params.seed),
-                    //         },
-                    //         start,
-                    //         queued: Instant::now(),
-                    //     })
-                    // });
-                }
-                Chunk::Image(_) => {}
-            }
-        };
+        // Stream generated tokens
+        // spawn_blocking(move || {
+        let num_generated_tokens = self
+            .inner
+            .borrow_mut()
+            .as_mut()
+            .expect("Failed to retrieve pointer to TRTLLM backend")
+            .stream(ctx, request_id, |ctx, token, step, is_final| {
+                // self.tokenizer.decode(&*[token], true).unwrap();
+                let sender = ctx.0;
+                let token = Token {
+                    id: token,
+                    text: String::from(""),
+                    logprob: 1.0f32,
+                    special: false,
+                };
+
+                sender
+                    .send(Ok(InferStreamResponse::Intermediate {
+                        token,
+                        top_tokens: vec![],
+                    }))
+                    .unwrap()
+            });
+
+        // Notify the end
+        let _ = ctx.0.send(Ok(InferStreamResponse::End {
+            token: Token {
+                id: 0,
+                text: String::from(""),
+                logprob: 1.0f32,
+                special: false,
+            },
+            top_tokens: vec![],
+            generated_text: GeneratedText {
+                text: String::from(""),
+                generated_tokens: num_generated_tokens,
+                finish_reason: FinishReason::EndOfSequenceToken,
+                seed: Some(params.seed),
+            },
+            start,
+            queued: Instant::now(),
+        }));
+        // });
+
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl Backend for TrtLLmBackend {
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+    ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
+        let (sender, receiver) = mpsc::unbounded_channel();
+        let ctx = GenerationContext(sender);
+
+        // Unpack parameters
+        let params = request.parameters;
+
+        // Ensure we are running in the right conditions for the input (i.e. single textual chunk)
+        let input = match request.inputs.len() {
+            0 => Err(InferError::GenerationError("No input provided".into())),
+            1 => Ok(request.inputs.first().unwrap()),
+            _ => Err(InferError::GenerationError(format!(
+                "Unsupported multi-chunks ({}) inference.",
+                request.inputs.len()
+            ))),
+        }?;
+
+        // Currently we handle single chunk of text
+        match input {
+            Chunk::Text(text) => {
+                self.infer_text(ctx, &**text, params)?;
+            }
+            Chunk::Image(_) => panic!("Unsupported"),
+        };
 
         Ok(UnboundedReceiverStream::new(receiver))
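
Note on the commit message: the streaming loop above still runs inline inside schedule/infer_text (the spawn_blocking call remains commented out), so the blocking FFI work is executed on the async runtime. Below is a minimal, self-contained sketch of the pattern this commit is working toward: running the blocking generation loop on Tokio's blocking pool while tokens flow to the async side through the same kind of unbounded channel. This is an illustration, not the crate's code; FakeBackend and next_token are hypothetical stand-ins for the FFI-backed TensorRtLlmBackendImpl handle, and doing this with the real handle would additionally require it to be Send.

// Sketch only (not part of the commit): blocking token generation on
// tokio's blocking pool, streamed to an async consumer over an mpsc channel.
use tokio::sync::mpsc;
use tokio::task::spawn_blocking;

// Hypothetical stand-in for the C++-backed backend handle.
struct FakeBackend;

impl FakeBackend {
    // Pretend each call performs one blocking decoding step.
    fn next_token(&mut self, step: u32) -> Option<u32> {
        if step < 8 { Some(1000 + step) } else { None }
    }
}

#[tokio::main]
async fn main() {
    let (sender, mut receiver) = mpsc::unbounded_channel::<u32>();

    // The generation loop runs on the dedicated blocking pool, so the
    // async executor is never stalled by the blocking call.
    let worker = spawn_blocking(move || {
        let mut backend = FakeBackend;
        let mut step = 0;
        while let Some(token_id) = backend.next_token(step) {
            // Forward every intermediate token to the async consumer;
            // stop early if the receiver has been dropped.
            if sender.send(token_id).is_err() {
                break;
            }
            step += 1;
        }
    });

    // Async side: stream tokens as they arrive.
    while let Some(token_id) = receiver.recv().await {
        println!("received token {token_id}");
    }
    worker.await.expect("blocking generation task panicked");
}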