diff --git a/router/src/server.rs b/router/src/server.rs
index 2f86bcb6..7181548b 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -64,6 +64,42 @@ use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
+fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
+    let offsets = encoding.get_offsets();
+    let input_ids = encoding.get_ids();
+    if offsets.len() == input_ids.len() {
+        encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, &(start, stop))| {
+                let text = input
+                    .chars()
+                    .skip(start)
+                    .take(stop - start)
+                    .collect::<String>();
+                SimpleToken {
+                    id,
+                    text,
+                    start,
+                    stop,
+                }
+            })
+            .collect()
+    } else {
+        encoding
+            .get_ids()
+            .iter()
+            .map(|&id| SimpleToken {
+                id,
+                text: "".to_string(),
+                start: 0,
+                stop: 0,
+            })
+            .collect()
+    }
+}
+
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
 #[utoipa::path(
     post,
@@ -161,24 +197,8 @@ async fn get_chat_tokenize(
     let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
     let input = generate_request.inputs.clone();
     let encoding = infer.tokenize(generate_request).await?;
-    let tokens: Vec<SimpleToken> = encoding
-        .get_ids()
-        .iter()
-        .zip(encoding.get_offsets())
-        .map(|(&id, &(start, stop))| {
-            let text = input
-                .chars()
-                .skip(start)
-                .take(stop - start)
-                .collect::<String>();
-            SimpleToken {
-                id,
-                text,
-                start,
-                stop,
-            }
-        })
-        .collect();
+
+    let tokens = encoding_to_tokens(&encoding, &input);
 
     let resp = ChatTokenizeResponse {
         tokenize_response: TokenizeResponse(tokens),
@@ -1448,24 +1468,7 @@ async fn tokenize(
 ) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
-    let tokens: Vec<SimpleToken> = encoding
-        .get_ids()
-        .iter()
-        .zip(encoding.get_offsets())
-        .map(|(&id, &(start, stop))| {
-            let text = input
-                .chars()
-                .skip(start)
-                .take(stop - start)
-                .collect::<String>();
-            SimpleToken {
-                id,
-                text,
-                start,
-                stop,
-            }
-        })
-        .collect();
+    let tokens = encoding_to_tokens(&encoding, &input);
     Ok(Json(TokenizeResponse(tokens)))
 }
 