Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-07-30 03:40:17 +00:00.
feat(trtllm): add stop sequence support

Support per-request stop sequences.

This commit is contained in:
parent 0858af206f
commit 27d03309c9
@@ -35,6 +35,9 @@ struct GenerationContext {
     tokens: Vec<u32>,
     start: Option<Instant>,
     queued: Instant,
+
+    /// output_buffer stores the output for detecting stop sequences
+    output_buffer: Option<String>,
 }

 #[derive(Debug, Copy, Clone)]
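Why a rolling buffer at all: a stop sequence typically decodes across several tokens, so no single token's text can match it on its own. The sketch below is illustrative only (none of these names come from the codebase); it shows a stop sequence arriving in three chunks and only becoming detectable once recent chunks are accumulated:

fn main() {
    // The stop sequence "</answer>" (9 bytes) arrives split across three
    // token strings; no individual chunk contains it.
    let mut window = String::new();
    for chunk in ["The end.", "</", "answer", ">"] {
        window.push_str(chunk);
        // (front-trimming elided; see the drain logic in the next hunk)
        if window.contains("</answer>") {
            println!("stop sequence completed while pushing {chunk:?}");
        }
    }
}

Keeping that window bounded, as the next hunks do by trimming it to roughly the longest stop sequence plus some slack, makes memory constant no matter how long the generation runs.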
@@ -191,11 +194,39 @@ fn executor_status_looper(
 fn post_process_decoded_token(
     tokenizer: &Tokenizer,
     ctx: &mut GenerationContext,
-    decoded_token: DecodedToken,
+    mut decoded_token: DecodedToken,
 ) -> InferResult<InferStreamResponse> {
     match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
             let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+
+            if let Some(buf) = ctx.output_buffer.as_mut() {
+                if buf.len() + text.len() > buf.capacity() {
+                    let mut start = buf.len() + text.len() - buf.capacity();
+                    while start <= buf.len() && !buf.is_char_boundary(start) {
+                        start += 1;
+                    }
+                    buf.drain(..start);
+                }
+                buf.push_str(&text);
+
+                for stop_seq in &ctx.request.stopping_parameters.stop_sequences {
+                    let start = if 1 + buf.len() > text.len() + stop_seq.len() {
+                        let mut start = 1 + buf.len() - text.len() - stop_seq.len();
+                        while start > 0 && !buf.is_char_boundary(start) {
+                            start -= 1;
+                        }
+                        start
+                    } else {
+                        0
+                    };
+                    if buf[start..].contains(stop_seq) {
+                        decoded_token.is_final = true;
+                        decoded_token.finish_reason = FinishReason::StopWords;
+                    }
+                }
+            }
+
             let token = Token {
                 id: decoded_token.id,
                 text,
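The scan bound is the subtle part. A stop sequence newly completed by this token must end somewhere inside `text`, so its first byte can sit no earlier than index `1 + buf.len() - text.len() - stop_seq.len()`; the backward `is_char_boundary` walk only widens that window to a valid UTF-8 boundary. Here is a standalone sketch of the bound with hypothetical inputs (nothing below is TGI code):

fn contains_new_stop(buf: &str, new_text_len: usize, stop_seq: &str) -> bool {
    // Earliest byte at which a match ending inside the new text could begin.
    let start = if 1 + buf.len() > new_text_len + stop_seq.len() {
        let mut start = 1 + buf.len() - new_text_len - stop_seq.len();
        while start > 0 && !buf.is_char_boundary(start) {
            start -= 1; // widen to a valid UTF-8 boundary
        }
        start
    } else {
        0
    };
    buf[start..].contains(stop_seq)
}

fn main() {
    // "world" straddles the boundary between the old buffer ("Hello wor")
    // and the newly decoded text ("ld!"): start = 1 + 12 - 3 - 5 = 5 and
    // buf[5..] == " world!", so the straddling match is still found.
    let buf = String::from("Hello world!");
    assert!(contains_new_stop(&buf, 3, "world"));

    // A match lying wholly in older text is outside the scanned suffix and
    // will not fire again on later tokens.
    assert!(!contains_new_stop(&buf, 1, "Hello"));
}

One edge case in the trim logic above seems worth flagging: if a single decoded token's `text` were ever longer than the buffer's whole capacity, the computed drain start would exceed `buf.len()` and `String::drain` would panic; clamping it with `.min(buf.len())` would guard against that. The TODO in the allocation hunk below points at the same question.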
@@ -344,12 +375,20 @@ impl Backend for TensorRtLlmBackendV2 {
 
         // Send the context to the executor for scheduling
         let queued = Instant::now();
+        let output_buffer = request
+            .stopping_parameters
+            .stop_sequences
+            .iter()
+            .map(|x| x.len())
+            .max()
+            .map(|m| String::with_capacity(m + 32)); // TODO: is this number enough?
         match self.0.send(GenerationContext {
             request,
             streamer,
             tokens: Vec::with_capacity(256),
             start: None,
             queued,
+            output_buffer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
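On sizing: `Iterator::max` over an empty iterator yields `None`, so a request with no stop sequences allocates no buffer at all and the detection path in `post_process_decoded_token` is skipped entirely. Otherwise the capacity is the longest stop sequence plus 32 bytes of slack for incoming token text. A small sketch of that rule with hypothetical stop sequences (only the sizing expression mirrors the diff):

fn main() {
    let stop_sequences: Vec<String> = vec!["###".into(), "\n\nUser:".into()];
    let output_buffer = stop_sequences
        .iter()
        .map(|s| s.len())
        .max()
        .map(|longest| String::with_capacity(longest + 32));
    // The longest stop sequence, "\n\nUser:", is 7 bytes, so at least
    // 7 + 32 bytes are reserved (with_capacity only guarantees a lower bound).
    assert!(output_buffer.unwrap().capacity() >= 7 + 32);

    // No stop sequences: max() is None, so no buffer is ever allocated.
    let no_stops: Vec<String> = Vec::new();
    assert!(no_stops.iter().map(|s| s.len()).max().is_none());
}

Whether 32 bytes of slack always covers one decoded token's text is exactly what the TODO questions; since the buffer is per-request and small, erring on the high side would be cheap.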