feat(backend): handle all tokenization failures and send them back to the client

This commit is contained in:
Morgan Funtowicz 2024-11-06 17:46:46 +01:00
parent 20652824d9
commit 26d0266cec

View File

@@ -124,56 +124,59 @@ fn llama_generate_callback(
     let token = match ctx.tokenizer.decode(&[new_token_id], false) {
         Ok(text) => {
             let special = ctx.tokenizer.get_added_vocabulary().is_special_token(&text);
-            Token {
+            Ok(Token {
                 id: new_token_id,
                 text,
                 logprob: new_token_logit,
                 special,
-            }
+            })
         }
-        Err(_) => panic!("Failed to decode token"),
+        Err(ref err) => Err(InferError::GenerationError(err.to_string())),
     };

     // Create the streamed response
-    let response = match is_final {
-        false => InferStreamResponse::Intermediate {
-            token,
-            top_tokens: vec![],
-        },
-        true => {
-            // Decode the whole text
-            match ctx
-                .tokenizer
-                .decode(&ctx.generation.generated_tokens, false)
-            {
-                Ok(text) => InferStreamResponse::End {
-                    token,
-                    top_tokens: vec![],
-                    generated_text: GeneratedText {
-                        text,
-                        generated_tokens: n_generated_tokens as u32,
-                        finish_reason: FinishReason::Length,
-                        seed: Some(ctx.generation.sampling_params.seed),
-                    },
-                    start: ctx.start,
-                    queued: ctx.start,
-                },
-                Err(_) => panic!("Failed to decode token"),
+    let response = match token {
+        Ok(token) => {
+            match is_final {
+                false => Ok(InferStreamResponse::Intermediate {
+                    token,
+                    top_tokens: vec![],
+                }),
+                true => {
+                    // Decode the whole text
+                    match ctx
+                        .tokenizer
+                        .decode(&ctx.generation.generated_tokens, false)
+                    {
+                        Ok(text) => Ok(InferStreamResponse::End {
+                            token,
+                            top_tokens: vec![],
+                            generated_text: GeneratedText {
+                                text,
+                                generated_tokens: n_generated_tokens as u32,
+                                finish_reason: FinishReason::Length,
+                                seed: Some(ctx.generation.sampling_params.seed),
+                            },
+                            start: ctx.start,
+                            queued: ctx.start,
+                        }),
+                        Err(err) => Err(InferError::GenerationError(err.to_string())),
+                    }
+                }
             }
+            // Stream end response
         }
+        Err(err) => Err(err),
     };

     // Send back to the client
-    if let Err(ref _err) = ctx.stream.send(Ok(response)) {
+    let should_stop = if let Err(ref _err) = ctx.stream.send(response) {
         error!("Failed to send back the response to the client, cancelling request");
-        // TODO: cancel the request
-        return true; // should_stop
-    }
+        true
+    } else {
+        true
+    };

-    // should_stop
-    false
+    should_stop
 }
 unsafe fn scheduler_loop(