diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 0428a4dc7..487c04577 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -128,7 +128,6 @@ fn executor_status_looper( } } - // info!("Num response ready: {}", backend.num_responses_ready()); if backend.num_responses_ready() > 0 { match backend.pin_mut().pull_tokens() { Ok(responses) => { @@ -172,20 +171,23 @@ fn post_processor_looper( tokenizer: Tokenizer, mut decoded_tokens: UnboundedReceiver<(u64, InferResult)>, ) { + let mut states: HashMap> = HashMap::with_capacity(128); + 'post_processor: loop { if decoded_tokens.is_closed() { warn!("Post processor IPC is closed, loop will exit now."); break 'post_processor; } - let mut states: HashMap> = HashMap::with_capacity(128); - if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() { - let state = states.entry(request_id).or_insert(vec![]); - match decoded { Ok(ctx) => { - state.push(ctx.token.id); + states.entry(request_id).and_modify(|s| s.push(*&ctx.token.id)).or_insert_with(|| { + let mut state = Vec::with_capacity(128); + state.push(*&ctx.token.id); + state + }); + let out = match tokenizer.decode(&[ctx.token.id], false) { Ok(text) => { let is_special = @@ -203,10 +205,11 @@ fn post_processor_looper( top_tokens: vec![], } } else { - let text = tokenizer.decode(&state, true); + let tokens = states.remove(&request_id).unwrap(); + let text = tokenizer.decode(&tokens, true); let generated_text = GeneratedText { text: text.unwrap(), - generated_tokens: state.len() as u32, + generated_tokens: tokens.len() as u32, finish_reason: FinishReason::EndOfSequenceToken, seed: None, };