From b291be64a036c12245a0811d58e0efece686b2d4 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Fri, 12 Jul 2024 19:26:32 +0000
Subject: [PATCH] impl the rust backend which currently cannot move the actual
 computation in background thread

---
 backends/trtllm/src/backend.rs | 182 +++++++++++++++++++--------------
 1 file changed, 103 insertions(+), 79 deletions(-)

diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index d4d3d00c..eec4e081 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -8,13 +8,16 @@ use tokio::sync::mpsc;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_trtllm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
-struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+type InferResult<T> = Result<T, InferError>;
+
+pub struct GenerationContext(mpsc::UnboundedSender<InferResult<InferStreamResponse>>);
 
 pub struct TrtLLmBackend {
     tokenizer: Tokenizer,
@@ -30,13 +33,92 @@ impl TrtLLmBackend {
         engine_folder: P,
     ) -> Result<Self, TensorRtLlmBackendError> {
         let engine_folder = engine_folder.as_ref();
-        let inner = create_trtllm_backend(engine_folder.to_str().unwrap(), "");
+        let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
 
         Ok(Self {
             tokenizer,
             inner: RefCell::new(inner),
         })
     }
+
+    fn infer_text(
+        &self,
+        ctx: GenerationContext,
+        text: &str,
+        params: ValidParameters,
+    ) -> InferResult<()> {
+        // Keep track of processing time
+        let start = Instant::now();
+
+        // Encode the input
+        let ctx = Box::new(ctx);
+        let encoding = self
+            .tokenizer
+            .encode(text, true)
+            .map_err(|e| InferError::ToolError(e.to_string()))?;
+
+        // Submit the request to the backend and retrieve the handle to query its status
+        let request_id = self
+            .inner
+            .borrow_mut()
+            .as_mut()
+            .expect("Failed to retrieve pointer to TRTLLM backend")
+            .submit(
+                encoding.get_ids(),
+                128,
+                params.top_k as i32,
+                params.top_p,
+                params.temperature,
+                params.seed,
+            );
+
+        // Stream generated tokens
+        // spawn_blocking(move || {
+        let num_generated_tokens = self
+            .inner
+            .borrow_mut()
+            .as_mut()
+            .expect("Failed to retrieve pointer to TRTLLM backend")
+            .stream(ctx, request_id, |ctx, token, step, is_final| {
+                // self.tokenizer.decode(&*[token], true).unwrap();
+                let sender = ctx.0;
+                let token = Token {
+                    id: token,
+                    text: String::from(""),
+                    logprob: 1.0f32,
+                    special: false,
+                };
+
+                sender
+                    .send(Ok(InferStreamResponse::Intermediate {
+                        token,
+                        top_tokens: vec![],
+                    }))
+                    .unwrap()
+            });
+
+        // Notify the end
+        let _ = ctx.0.send(Ok(InferStreamResponse::End {
+            token: Token {
+                id: 0,
+                text: String::from(""),
+                logprob: 1.0f32,
+                special: false,
+            },
+            top_tokens: vec![],
+            generated_text: GeneratedText {
+                text: String::from(""),
+                generated_tokens: num_generated_tokens,
+                finish_reason: FinishReason::EndOfSequenceToken,
+                seed: Some(params.seed),
+            },
+            start,
+            queued: Instant::now(),
+        }));
+        // });
+
+        Ok(())
+    }
 }
 
 #[async_trait]
@@ -44,87 +126,29 @@ impl Backend for TrtLLmBackend {
     fn schedule(
         &self,
         request: ValidGenerateRequest,
-    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
         let (sender, receiver) = mpsc::unbounded_channel();
-        let ctx = Box::new(GenerationContext(sender));
+        let ctx = GenerationContext(sender);
 
         // Unpack parameters
         let params = request.parameters;
 
+        // Ensure we are running in the right conditions for the input (i.e. single textual chunk)
+        let input = match request.inputs.len() {
+            0 => Err(InferError::GenerationError("No input provided".into())),
+            1 => Ok(request.inputs.first().unwrap()),
+            _ => Err(InferError::GenerationError(format!(
+                "Unsupported multi-chunks ({}) inference.",
+                request.inputs.len()
+            ))),
+        }?;
+
         // Currently we handle single chunk of text
-        if request.inputs.len() == 1 {
-            match request
-                .inputs
-                .first()
-                .expect("Failed to access the first chunk")
-            {
-                Chunk::Text(text) => {
-                    let encoding = self
-                        .tokenizer
-                        .encode(&**text, true)
-                        .map_err(|e| InferError::ToolError(e.to_string()))?;
-
-                    let _start = Instant::now();
-                    let _request_id = self
-                        .inner
-                        .borrow_mut()
-                        .as_mut()
-                        .expect("Failed to retrieve pointer to TRTLLM backend")
-                        .submit(
-                            encoding.get_ids(),
-                            128,
-                            params.top_k as i32,
-                            params.top_p,
-                            params.temperature,
-                            params.seed,
-                        );
-
-                    // spawn_blocking(|| {
-                    //     // Stream generated tokens
-                    //     let num_generated_tokens = self
-                    //         .inner
-                    //         .borrow_mut()
-                    //         .as_mut()
-                    //         .expect("Failed to retrieve pointer to TRTLLM backend")
-                    //         .stream(request_id, ctx, |token, step, is_final| {
-                    //             // self.tokenizer.decode(&*[token], true).unwrap();
-                    //             let token = Token {
-                    //                 id: token,
-                    //                 text: String::from(""),
-                    //                 logprob: 1.0f32,
-                    //                 special: false,
-                    //             };
-                    //
-                    //             sender
-                    //                 .send(Ok(InferStreamResponse::Intermediate {
-                    //                     token,
-                    //                     top_tokens: vec![],
-                    //                 }))
-                    //                 .unwrap()
-                    //         });
-                    //
-                    //     // Notify the end
-                    //     Ok(InferStreamResponse::End {
-                    //         token: Token {
-                    //             id: 0,
-                    //             text: String::from(""),
-                    //             logprob: 1.0f32,
-                    //             special: false,
-                    //         },
-                    //         top_tokens: vec![],
-                    //         generated_text: GeneratedText {
-                    //             text: String::from(""),
-                    //             generated_tokens: num_generated_tokens,
-                    //             finish_reason: FinishReason::EndOfSequenceToken,
-                    //             seed: Some(params.seed),
-                    //         },
-                    //         start,
-                    //         queued: Instant::now(),
-                    //     })
-                    // });
-                }
-                Chunk::Image(_) => {}
-            }
+        match input {
+            Chunk::Text(text) => {
+                self.infer_text(ctx, &**text, params)?;
+            }
+            Chunk::Image(_) => panic!("Unsupported"),
         };
 
         Ok(UnboundedReceiverStream::new(receiver))
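
The commented-out spawn_blocking blocks above are what the subject line refers to: submit and stream are blocking calls into the TensorRT-LLM FFI layer, yet they still run inline on the async runtime thread, because the backend handle sits in a RefCell reachable only through &self and therefore cannot satisfy the Send + 'static bound of tokio::task::spawn_blocking. The sketch below shows the general pattern only, under stated assumptions: MockBackend is a hypothetical stand-in for the cxx TensorRtLlmBackendImpl handle, and the real handle would additionally have to be made Send (for instance behind an Arc<Mutex<..>>) before this shape could be used with the actual FFI type.

use std::sync::{Arc, Mutex};

use tokio::sync::mpsc;
use tokio::task::spawn_blocking;

// Hypothetical stand-in for the cxx TensorRtLlmBackendImpl handle; the real
// FFI object would have to be Send (or wrapped accordingly) to move into
// spawn_blocking at all.
struct MockBackend;

impl MockBackend {
    // Blocking loop that yields one token id per step, mirroring the shape of
    // the `stream(ctx, request_id, callback)` call in the patch.
    fn stream(&mut self, mut on_token: impl FnMut(u32, usize, bool)) {
        for step in 0..4u32 {
            std::thread::sleep(std::time::Duration::from_millis(10));
            on_token(step, step as usize, step == 3);
        }
    }
}

#[tokio::main]
async fn main() {
    let backend = Arc::new(Mutex::new(MockBackend));
    let (sender, mut receiver) = mpsc::unbounded_channel::<u32>();

    // Run the blocking token loop off the async runtime; generated ids are
    // pushed back to the async side through the unbounded channel, like the
    // GenerationContext sender in the patch.
    let worker = {
        let backend = Arc::clone(&backend);
        spawn_blocking(move || {
            backend
                .lock()
                .expect("backend mutex poisoned")
                .stream(|token, _step, _is_final| {
                    // Ignore the error if the receiving side has gone away.
                    let _ = sender.send(token);
                });
        })
    };

    // The async side only drains the channel while the blocking work runs.
    while let Some(token) = receiver.recv().await {
        println!("received token id {token}");
    }

    worker.await.expect("blocking task panicked");
}

On the async side nothing changes with this split: the receiver is drained exactly as UnboundedReceiverStream::new(receiver) already does in schedule.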
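
Separately, the Token values emitted in infer_text carry an empty text field and a placeholder logprob of 1.0; the commented self.tokenizer.decode(&*[token], true) line marks the missing decoding step. A minimal helper along these lines could fill the text field, assuming a tokenizers version whose decode accepts a slice of ids; decode_token is a hypothetical name, and decoding one id at a time is the simplest possible shape rather than a proper incremental decoder.

use tokenizers::Tokenizer;

// Hypothetical helper: map a single generated token id to its surface text.
// A production implementation would keep the accumulated ids and decode
// incrementally, since per-id decoding loses merge context for many tokenizers.
fn decode_token(tokenizer: &Tokenizer, id: u32) -> Result<String, String> {
    tokenizer
        .decode(&[id], false)
        .map_err(|e| e.to_string())
}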