(fix): do not recreate the stateful hashmap at every iteration

This commit is contained in:
Morgan Funtowicz 2024-10-10 12:41:46 +00:00 committed by Morgan Funtowicz
parent eb13d8d1f3
commit c8a99af6c9

View File

@ -128,7 +128,6 @@ fn executor_status_looper(
} }
} }
// info!("Num response ready: {}", backend.num_responses_ready());
if backend.num_responses_ready() > 0 { if backend.num_responses_ready() > 0 {
match backend.pin_mut().pull_tokens() { match backend.pin_mut().pull_tokens() {
Ok(responses) => { Ok(responses) => {
@ -172,20 +171,23 @@ fn post_processor_looper(
tokenizer: Tokenizer, tokenizer: Tokenizer,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>, mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) { ) {
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(128);
'post_processor: loop { 'post_processor: loop {
if decoded_tokens.is_closed() { if decoded_tokens.is_closed() {
warn!("Post processor IPC is closed, loop will exit now."); warn!("Post processor IPC is closed, loop will exit now.");
break 'post_processor; break 'post_processor;
} }
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(128);
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() { if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
let state = states.entry(request_id).or_insert(vec![]);
match decoded { match decoded {
Ok(ctx) => { Ok(ctx) => {
state.push(ctx.token.id); states.entry(request_id).and_modify(|s| s.push(*&ctx.token.id)).or_insert_with(|| {
let mut state = Vec::with_capacity(128);
state.push(*&ctx.token.id);
state
});
let out = match tokenizer.decode(&[ctx.token.id], false) { let out = match tokenizer.decode(&[ctx.token.id], false) {
Ok(text) => { Ok(text) => {
let is_special = let is_special =
@ -203,10 +205,11 @@ fn post_processor_looper(
top_tokens: vec![], top_tokens: vec![],
} }
} else { } else {
let text = tokenizer.decode(&state, true); let tokens = states.remove(&request_id).unwrap();
let text = tokenizer.decode(&tokens, true);
let generated_text = GeneratedText { let generated_text = GeneratedText {
text: text.unwrap(), text: text.unwrap(),
generated_tokens: state.len() as u32, generated_tokens: tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken, finish_reason: FinishReason::EndOfSequenceToken,
seed: None, seed: None,
}; };