feat(trtllm): add stop sequence support

Support per request stop sequences.
This commit is contained in:
Tzu-Yu Lee 2025-05-18 02:37:19 +08:00
parent 0858af206f
commit 27d03309c9

View File

@ -35,6 +35,9 @@ struct GenerationContext {
tokens: Vec<u32>,
start: Option<Instant>,
queued: Instant,
/// output_buffer stores the output for detecting stop sequences
output_buffer: Option<String>,
}
#[derive(Debug, Copy, Clone)]
@ -191,11 +194,39 @@ fn executor_status_looper(
fn post_process_decoded_token(
tokenizer: &Tokenizer,
ctx: &mut GenerationContext,
decoded_token: DecodedToken,
mut decoded_token: DecodedToken,
) -> InferResult<InferStreamResponse> {
match tokenizer.decode(&[decoded_token.id], false) {
Ok(text) => {
let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
if let Some(buf) = ctx.output_buffer.as_mut() {
if buf.len() + text.len() > buf.capacity() {
let mut start = buf.len() + text.len() - buf.capacity();
while start <= buf.len() && !buf.is_char_boundary(start) {
start += 1;
}
buf.drain(..start);
}
buf.push_str(&text);
for stop_seq in &ctx.request.stopping_parameters.stop_sequences {
let start = if 1 + buf.len() > text.len() + stop_seq.len() {
let mut start = 1 + buf.len() - text.len() - stop_seq.len();
while start > 0 && !buf.is_char_boundary(start) {
start -= 1;
}
start
} else {
0
};
if buf[start..].contains(stop_seq) {
decoded_token.is_final = true;
decoded_token.finish_reason = FinishReason::StopWords;
}
}
}
let token = Token {
id: decoded_token.id,
text,
@ -344,12 +375,20 @@ impl Backend for TensorRtLlmBackendV2 {
// Send the context to the executor for scheduling
let queued = Instant::now();
let output_buffer = request
.stopping_parameters
.stop_sequences
.iter()
.map(|x| x.len())
.max()
.map(|m| String::with_capacity(m + 32)); // TODO: is this number enough?
match self.0.send(GenerationContext {
request,
streamer,
tokens: Vec::with_capacity(256),
start: None,
queued,
output_buffer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(