Handling potential lack of offsets (python tokenizer)

Nicolas Patry 2024-09-17 16:56:19 +02:00
parent 5ba7805f1c
commit 9d702bcde3


@@ -64,6 +64,42 @@ use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
+fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
+    let offsets = encoding.get_offsets();
+    let input_ids = encoding.get_ids();
+    if offsets.len() == input_ids.len() {
+        encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, &(start, stop))| {
+                let text = input
+                    .chars()
+                    .skip(start)
+                    .take(stop - start)
+                    .collect::<String>();
+                SimpleToken {
+                    id,
+                    text,
+                    start,
+                    stop,
+                }
+            })
+            .collect()
+    } else {
+        encoding
+            .get_ids()
+            .iter()
+            .map(|&id| SimpleToken {
+                id,
+                text: "".to_string(),
+                start: 0,
+                stop: 0,
+            })
+            .collect()
+    }
+}
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
 #[utoipa::path(
     post,
@@ -161,24 +197,8 @@ async fn get_chat_tokenize(
     let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
     let input = generate_request.inputs.clone();
     let encoding = infer.tokenize(generate_request).await?;
-    let tokens: Vec<SimpleToken> = encoding
-        .get_ids()
-        .iter()
-        .zip(encoding.get_offsets())
-        .map(|(&id, &(start, stop))| {
-            let text = input
-                .chars()
-                .skip(start)
-                .take(stop - start)
-                .collect::<String>();
-            SimpleToken {
-                id,
-                text,
-                start,
-                stop,
-            }
-        })
-        .collect();
+    let tokens = encoding_to_tokens(&encoding, &input);
     let resp = ChatTokenizeResponse {
         tokenize_response: TokenizeResponse(tokens),
@@ -1448,24 +1468,7 @@ async fn tokenize(
 ) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
-    let tokens: Vec<SimpleToken> = encoding
-        .get_ids()
-        .iter()
-        .zip(encoding.get_offsets())
-        .map(|(&id, &(start, stop))| {
-            let text = input
-                .chars()
-                .skip(start)
-                .take(stop - start)
-                .collect::<String>();
-            SimpleToken {
-                id,
-                text,
-                start,
-                stop,
-            }
-        })
-        .collect();
+    let tokens = encoding_to_tokens(&encoding, &input);
     Ok(Json(TokenizeResponse(tokens)))
 }
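
Outside the commit itself, here is a minimal self-contained sketch of the fallback behavior, under stated assumptions: `SimpleToken` below is a stand-in with the same field names as the router's struct, `to_simple_tokens` is a hypothetical helper that mirrors `encoding_to_tokens` without depending on the `tokenizers` crate, and the ids and offsets are made-up values. It only illustrates that token ids are preserved while text and spans are emptied when no offsets are available:

// Stand-in for the router's SimpleToken (hypothetical, same field names).
#[derive(Debug)]
struct SimpleToken {
    id: u32,
    text: String,
    start: usize,
    stop: usize,
}

// Hypothetical helper mirroring `encoding_to_tokens` on plain slices.
fn to_simple_tokens(ids: &[u32], offsets: &[(usize, usize)], input: &str) -> Vec<SimpleToken> {
    if offsets.len() == ids.len() {
        // Offsets available (fast tokenizer): slice the input per token.
        ids.iter()
            .zip(offsets)
            .map(|(&id, &(start, stop))| SimpleToken {
                id,
                text: input.chars().skip(start).take(stop - start).collect(),
                start,
                stop,
            })
            .collect()
    } else {
        // No offsets (python/slow tokenizer): keep ids, emit empty text and zeroed spans.
        ids.iter()
            .map(|&id| SimpleToken {
                id,
                text: String::new(),
                start: 0,
                stop: 0,
            })
            .collect()
    }
}

fn main() {
    let input = "Hello world";
    // Made-up ids and character offsets.
    let with_offsets = to_simple_tokens(&[101, 102], &[(0, 5), (6, 11)], input);
    let without_offsets = to_simple_tokens(&[101, 102], &[], input);
    println!("{:?}", with_offsets);    // texts "Hello" / "world"
    println!("{:?}", without_offsets); // empty texts, zeroed spans
}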