diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index fd0bc967..6d7f30c3 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -35,6 +35,9 @@ struct GenerationContext { tokens: Vec, start: Option, queued: Instant, + + /// output_buffer stores the output for detecting stop sequences + output_buffer: Option, } #[derive(Debug, Copy, Clone)] @@ -191,11 +194,39 @@ fn executor_status_looper( fn post_process_decoded_token( tokenizer: &Tokenizer, ctx: &mut GenerationContext, - decoded_token: DecodedToken, + mut decoded_token: DecodedToken, ) -> InferResult { match tokenizer.decode(&[decoded_token.id], false) { Ok(text) => { let is_special = tokenizer.get_added_vocabulary().is_special_token(&text); + + if let Some(buf) = ctx.output_buffer.as_mut() { + if buf.len() + text.len() > buf.capacity() { + let mut start = buf.len() + text.len() - buf.capacity(); + while start <= buf.len() && !buf.is_char_boundary(start) { + start += 1; + } + buf.drain(..start); + } + buf.push_str(&text); + + for stop_seq in &ctx.request.stopping_parameters.stop_sequences { + let start = if 1 + buf.len() > text.len() + stop_seq.len() { + let mut start = 1 + buf.len() - text.len() - stop_seq.len(); + while start > 0 && !buf.is_char_boundary(start) { + start -= 1; + } + start + } else { + 0 + }; + if buf[start..].contains(stop_seq) { + decoded_token.is_final = true; + decoded_token.finish_reason = FinishReason::StopWords; + } + } + } + let token = Token { id: decoded_token.id, text, @@ -344,12 +375,20 @@ impl Backend for TensorRtLlmBackendV2 { // Send the context to the executor for scheduling let queued = Instant::now(); + let output_buffer = request + .stopping_parameters + .stop_sequences + .iter() + .map(|x| x.len()) + .max() + .map(|m| String::with_capacity(m + 32)); // TODO: is this number enough? match self.0.send(GenerationContext { request, streamer, tokens: Vec::with_capacity(256), start: None, queued, + output_buffer, }) { Ok(_) => Ok(UnboundedReceiverStream::new(receiver)), Err(_) => Err(GenerationError(