mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Handling potential lack of offsets (python tokenizer)
This commit is contained in:
parent
5ba7805f1c
commit
9d702bcde3
@ -64,6 +64,42 @@ use tracing::{info_span, instrument, Instrument};
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
|
||||
let offsets = encoding.get_offsets();
|
||||
let input_ids = encoding.get_ids();
|
||||
if offsets.len() == input_ids.len() {
|
||||
encoding
|
||||
.get_ids()
|
||||
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text = input
|
||||
.chars()
|
||||
.skip(start)
|
||||
.take(stop - start)
|
||||
.collect::<String>();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
|
||||
start,
|
||||
stop,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
encoding
|
||||
.get_ids()
|
||||
.iter()
|
||||
.map(|&id| SimpleToken {
|
||||
id,
|
||||
text: "".to_string(),
|
||||
start: 0,
|
||||
stop: 0,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
|
||||
#[utoipa::path(
|
||||
post,
|
||||
@ -161,24 +197,8 @@ async fn get_chat_tokenize(
|
||||
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
|
||||
let input = generate_request.inputs.clone();
|
||||
let encoding = infer.tokenize(generate_request).await?;
|
||||
let tokens: Vec<SimpleToken> = encoding
|
||||
.get_ids()
|
||||
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text = input
|
||||
.chars()
|
||||
.skip(start)
|
||||
.take(stop - start)
|
||||
.collect::<String>();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
|
||||
start,
|
||||
stop,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tokens = encoding_to_tokens(&encoding, &input);
|
||||
|
||||
let resp = ChatTokenizeResponse {
|
||||
tokenize_response: TokenizeResponse(tokens),
|
||||
@ -1448,24 +1468,7 @@ async fn tokenize(
|
||||
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let input = req.inputs.clone();
|
||||
let encoding = infer.tokenize(req).await?;
|
||||
let tokens: Vec<SimpleToken> = encoding
|
||||
.get_ids()
|
||||
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text = input
|
||||
.chars()
|
||||
.skip(start)
|
||||
.take(stop - start)
|
||||
.collect::<String>();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
|
||||
start,
|
||||
stop,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let tokens = encoding_to_tokens(&encoding, &input);
|
||||
Ok(Json(TokenizeResponse(tokens)))
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user