From 8d1c3c8ad445c540fad2f642be127f3c85bdb96c Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 21 Oct 2024 15:06:54 +0200
Subject: [PATCH] feat(trtllm): do not tokenize twice

---
 backends/trtllm/src/looper.rs | 35 +++++++++++++----------------------
 1 file changed, 13 insertions(+), 22 deletions(-)

diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 1411a8ea..d97fa69f 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -33,15 +33,9 @@ struct IdentifiableRequest {
     inner: T,
 }
 
-/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
-struct ValidGenerateRequestWithTokens {
-    encoding: Encoding,
-    inner: ValidGenerateRequest,
-}
-
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
+    request: ValidGenerateRequest,
     start: Option<Instant>,
     queued: Instant,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
@@ -97,12 +91,13 @@ fn executor_status_looper(
             if let Some(mut ctx) = waiting_requests.blocking_recv() {
                 // Submit all the request to the executor and move the context to the in-flight tracker
                 let request = &ctx.request;
-                let generation_params = &request.inner.parameters;
-                let stopping_params = &request.inner.stopping_parameters;
+                let generation_params = &request.parameters;
+                let stopping_params = &request.stopping_parameters;
+                let input_ids = request.input_ids.as_deref();
 
                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
-                    request.encoding.get_ids(),
+                    &input_ids.unwrap(), // This is checked beforehand in validate()
                     stopping_params.max_new_tokens,
                     generation_params.top_k as i32,
                     generation_params.top_p,
@@ -343,7 +338,11 @@ impl TensorRtLlmBackendV2 {
         })
     }
 
-    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+    fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
+        if request.input_ids.is_none() {
+            return Err(ValidationError(UnsupportedModality("No token provided")));
+        }
+
         if request.top_n_tokens > 1 {
             return Err(ValidationError(TopNTokensDisabled));
         }
@@ -359,7 +358,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(text),
+                Chunk::Text(text) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
@@ -372,15 +371,7 @@ impl Backend for TensorRtLlmBackendV2 {
         &self,
         inner: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        let prompt = Self::validate(&inner)?;
-
-        // We encode the prompt in every request context/thread
-        let encoding = self
-            .tokenizer
-            .encode(prompt.as_str(), true)
-            .map_err(|e| GenerationError(format!("Tokenization failed {}", e.to_string())))?;
-
-        let request = ValidGenerateRequestWithTokens { encoding, inner };
+        Self::validate(&inner)?;
 
         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
@@ -388,7 +379,7 @@ impl Backend for TensorRtLlmBackendV2 {
 
         // Send the context to the executor for scheduling
         let queued = Instant::now();
         match self.executor.send(GenerationContext {
-            request,
+            request: inner,
             start: None,
             queued,
             streamer,
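
The patch relies on ValidGenerateRequest.input_ids already being populated by the upstream validation/tokenization step, so the backend no longer calls tokenizer.encode a second time; validate() rejects requests without token ids, which is what makes the later unwrap() safe. Below is a minimal, self-contained sketch of that validate-then-unwrap pattern. It uses hypothetical stand-ins (MockRequest, ScheduleError, submit, schedule), not the TGI types or APIs.

// Sketch only: hypothetical stand-ins, not the TGI types.
#[derive(Debug)]
enum ScheduleError {
    MissingTokens,
}

struct MockRequest {
    // Assumed to be filled by an upstream validation/tokenization step.
    input_ids: Option<Vec<u32>>,
}

fn validate(request: &MockRequest) -> Result<(), ScheduleError> {
    // Mirrors the new check: refuse requests that were not pre-tokenized.
    if request.input_ids.is_none() {
        return Err(ScheduleError::MissingTokens);
    }
    Ok(())
}

fn submit(input_ids: &[u32]) -> usize {
    // Stand-in for the executor call; returns a fake request id.
    input_ids.len()
}

fn schedule(request: MockRequest) -> Result<usize, ScheduleError> {
    validate(&request)?;
    // unwrap() is safe here because validate() guarantees input_ids is Some.
    let ids = request.input_ids.unwrap();
    Ok(submit(&ids))
}

fn main() {
    println!("{:?}", schedule(MockRequest { input_ids: Some(vec![1, 2, 3]) }));
    println!("{:?}", schedule(MockRequest { input_ids: None }));
}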