mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Handling potential lack of offsets (python tokenizer)
This commit is contained in:
parent
5ba7805f1c
commit
9d702bcde3
@ -64,6 +64,42 @@ use tracing::{info_span, instrument, Instrument};
|
|||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
|
||||||
|
let offsets = encoding.get_offsets();
|
||||||
|
let input_ids = encoding.get_ids();
|
||||||
|
if offsets.len() == input_ids.len() {
|
||||||
|
encoding
|
||||||
|
.get_ids()
|
||||||
|
.iter()
|
||||||
|
.zip(encoding.get_offsets())
|
||||||
|
.map(|(&id, &(start, stop))| {
|
||||||
|
let text = input
|
||||||
|
.chars()
|
||||||
|
.skip(start)
|
||||||
|
.take(stop - start)
|
||||||
|
.collect::<String>();
|
||||||
|
SimpleToken {
|
||||||
|
id,
|
||||||
|
text,
|
||||||
|
start,
|
||||||
|
stop,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
encoding
|
||||||
|
.get_ids()
|
||||||
|
.iter()
|
||||||
|
.map(|&id| SimpleToken {
|
||||||
|
id,
|
||||||
|
text: "".to_string(),
|
||||||
|
start: 0,
|
||||||
|
stop: 0,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
|
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
@ -161,24 +197,8 @@ async fn get_chat_tokenize(
|
|||||||
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
|
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
|
||||||
let input = generate_request.inputs.clone();
|
let input = generate_request.inputs.clone();
|
||||||
let encoding = infer.tokenize(generate_request).await?;
|
let encoding = infer.tokenize(generate_request).await?;
|
||||||
let tokens: Vec<SimpleToken> = encoding
|
|
||||||
.get_ids()
|
let tokens = encoding_to_tokens(&encoding, &input);
|
||||||
.iter()
|
|
||||||
.zip(encoding.get_offsets())
|
|
||||||
.map(|(&id, &(start, stop))| {
|
|
||||||
let text = input
|
|
||||||
.chars()
|
|
||||||
.skip(start)
|
|
||||||
.take(stop - start)
|
|
||||||
.collect::<String>();
|
|
||||||
SimpleToken {
|
|
||||||
id,
|
|
||||||
text,
|
|
||||||
start,
|
|
||||||
stop,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let resp = ChatTokenizeResponse {
|
let resp = ChatTokenizeResponse {
|
||||||
tokenize_response: TokenizeResponse(tokens),
|
tokenize_response: TokenizeResponse(tokens),
|
||||||
@ -1448,24 +1468,7 @@ async fn tokenize(
|
|||||||
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
let input = req.inputs.clone();
|
let input = req.inputs.clone();
|
||||||
let encoding = infer.tokenize(req).await?;
|
let encoding = infer.tokenize(req).await?;
|
||||||
let tokens: Vec<SimpleToken> = encoding
|
let tokens = encoding_to_tokens(&encoding, &input);
|
||||||
.get_ids()
|
|
||||||
.iter()
|
|
||||||
.zip(encoding.get_offsets())
|
|
||||||
.map(|(&id, &(start, stop))| {
|
|
||||||
let text = input
|
|
||||||
.chars()
|
|
||||||
.skip(start)
|
|
||||||
.take(stop - start)
|
|
||||||
.collect::<String>();
|
|
||||||
SimpleToken {
|
|
||||||
id,
|
|
||||||
text,
|
|
||||||
start,
|
|
||||||
stop,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Ok(Json(TokenizeResponse(tokens)))
|
Ok(Json(TokenizeResponse(tokens)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user