Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-07-31 04:10:16 +00:00
feat(trtllm): add stop sequence support

Support per-request stop sequences.

parent 0858af206f
commit 27d03309c9
@@ -35,6 +35,9 @@ struct GenerationContext {
     tokens: Vec<u32>,
     start: Option<Instant>,
     queued: Instant,
+
+    /// output_buffer stores the output for detecting stop sequences
+    output_buffer: Option<String>,
 }

 #[derive(Debug, Copy, Clone)]
@@ -191,11 +194,39 @@ fn executor_status_looper(
 fn post_process_decoded_token(
     tokenizer: &Tokenizer,
     ctx: &mut GenerationContext,
-    decoded_token: DecodedToken,
+    mut decoded_token: DecodedToken,
 ) -> InferResult<InferStreamResponse> {
     match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
             let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+
+            if let Some(buf) = ctx.output_buffer.as_mut() {
+                if buf.len() + text.len() > buf.capacity() {
+                    let mut start = buf.len() + text.len() - buf.capacity();
+                    while start <= buf.len() && !buf.is_char_boundary(start) {
+                        start += 1;
+                    }
+                    buf.drain(..start);
+                }
+                buf.push_str(&text);
+
+                for stop_seq in &ctx.request.stopping_parameters.stop_sequences {
+                    let start = if 1 + buf.len() > text.len() + stop_seq.len() {
+                        let mut start = 1 + buf.len() - text.len() - stop_seq.len();
+                        while start > 0 && !buf.is_char_boundary(start) {
+                            start -= 1;
+                        }
+                        start
+                    } else {
+                        0
+                    };
+                    if buf[start..].contains(stop_seq) {
+                        decoded_token.is_final = true;
+                        decoded_token.finish_reason = FinishReason::StopWords;
+                    }
+                }
+            }
+
             let token = Token {
                 id: decoded_token.id,
                 text,
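As a reading aid, here is a minimal standalone sketch of the rolling-buffer check introduced in the hunk above, with the TGI-specific types (GenerationContext, DecodedToken, Token) stripped away. The function name push_and_check and the example stop sequence are illustrative only, not part of the commit.

// Illustrative sketch, not part of the commit: the same rolling-buffer
// stop-sequence check as in the hunk above, decoupled from the TGI types.
fn push_and_check(buf: &mut String, text: &str, stop_sequences: &[String]) -> bool {
    // Trim the front of the buffer (on a char boundary, so the String stays
    // valid UTF-8) whenever appending the new text would exceed its capacity.
    if buf.len() + text.len() > buf.capacity() {
        let mut start = buf.len() + text.len() - buf.capacity();
        while start <= buf.len() && !buf.is_char_boundary(start) {
            start += 1;
        }
        buf.drain(..start);
    }
    buf.push_str(text);

    // A stop sequence can only be completed by the newly appended text,
    // so only the tail of the buffer needs to be searched.
    stop_sequences.iter().any(|stop_seq| {
        let start = if 1 + buf.len() > text.len() + stop_seq.len() {
            let mut start = 1 + buf.len() - text.len() - stop_seq.len();
            while start > 0 && !buf.is_char_boundary(start) {
                start -= 1;
            }
            start
        } else {
            0
        };
        buf[start..].contains(stop_seq.as_str())
    })
}

fn main() {
    let stops = vec!["</answer>".to_string()];
    let mut buf = String::with_capacity(stops.iter().map(|s| s.len()).max().unwrap() + 32);
    // The stop sequence arrives split across two decoded tokens.
    assert!(!push_and_check(&mut buf, "42</ans", &stops));
    assert!(push_and_check(&mut buf, "wer>", &stops));
}

Keeping a small buffer of recent output, rather than inspecting each token's text in isolation, is what lets a stop sequence that is split across token boundaries be detected.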
@@ -344,12 +375,20 @@ impl Backend for TensorRtLlmBackendV2 {

         // Send the context to the executor for scheduling
         let queued = Instant::now();
+        let output_buffer = request
+            .stopping_parameters
+            .stop_sequences
+            .iter()
+            .map(|x| x.len())
+            .max()
+            .map(|m| String::with_capacity(m + 32)); // TODO: is this number enough?
         match self.0.send(GenerationContext {
             request,
             streamer,
             tokens: Vec::with_capacity(256),
             start: None,
             queued,
+            output_buffer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
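The capacity chosen above only needs to cover the longest stop sequence plus some slack for the incoming token text; the "+ 32" flagged by the TODO is that slack. A small illustrative sketch of the same derivation, using made-up stop sequences:

// Illustrative only: how the per-request output_buffer capacity above is
// derived. The stop sequences here are made-up examples, not TGI defaults.
fn main() {
    let stop_sequences = vec!["\n\nUser:".to_string(), "</s>".to_string()];
    let output_buffer: Option<String> = stop_sequences
        .iter()
        .map(|x| x.len())
        .max()
        .map(|m| String::with_capacity(m + 32)); // slack for the appended token text
    // With no stop sequences, max() yields None, no buffer is allocated,
    // and the detection pass in post_process_decoded_token is skipped entirely.
    assert!(output_buffer.unwrap().capacity() >= "\n\nUser:".len() + 32);
}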