Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00
Fixing the GIL locking.

This commit is contained in:
parent c0151cc14a
commit 9d7a95b24b
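The gist of the change, as visible in the hunks below: the tokenizer worker previously called pyo3::Python::with_gil unconditionally and held the GIL around its entire request loop, even when the tokenizer was a pure-Rust tokenizers::Tokenizer. The worker now matches on the tokenizer variant and only enters with_gil for the Python-backed tokenizer, so the Rust path never touches the GIL. Below is a minimal, self-contained sketch of that pattern, not the actual router code; the Backend enum and worker function are hypothetical stand-ins, and running the Python branch assumes pyo3 with either the auto-initialize feature or an explicit prepare_freethreaded_python call.

// Sketch only: hypothetical `Backend` stands in for `Tokenizer`; the worker
// acquires the GIL exclusively on the Python-backed branch.
use pyo3::prelude::*;

enum Backend {
    Python { module: String }, // stand-in for Tokenizer::Python { tokenizer_name, revision }
    Rust(String),              // stand-in for Tokenizer::Rust(tokenizers::Tokenizer)
}

fn worker(backend: Backend, inputs: Vec<String>) {
    match backend {
        Backend::Python { module } => {
            // GIL is acquired once, and only on this branch.
            Python::with_gil(|py| -> PyResult<()> {
                let len = py.import_bound(module.as_str())?.getattr("len")?;
                for input in inputs {
                    let n: usize = len.call1((input.as_str(),))?.extract()?;
                    println!("python len: {n}");
                }
                Ok(())
            })
            .expect("python worker failed");
        }
        Backend::Rust(prefix) => {
            // Pure-Rust branch: the GIL is never touched here.
            for input in inputs {
                println!("{prefix}{}", input.len());
            }
        }
    }
}

fn main() {
    // Needed for a standalone binary unless pyo3's auto-initialize feature is enabled.
    pyo3::prepare_freethreaded_python();
    worker(Backend::Rust("rust len: ".into()), vec!["hello".into()]);
    worker(Backend::Python { module: "builtins".into() }, vec!["world".into()]);
}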
@@ -31,70 +31,72 @@ pub enum Tokenizer {
     Rust(tokenizers::Tokenizer),
 }
 
-impl Tokenizer {
-    fn into_owned<'a>(self, py: Python<'a>) -> OwnedTokenizer<'a> {
-        match self {
-            Self::Python {
-                tokenizer_name,
-                revision,
-            } => {
-                let pytok = || -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
-                    let transformers = py.import_bound("transformers")?;
-                    let auto = transformers.getattr("AutoTokenizer")?;
-                    let from_pretrained = auto.getattr("from_pretrained")?;
-                    let args = (tokenizer_name.to_string(),);
-                    let kwargs = if let Some(rev) = &revision {
-                        [("revision", rev.to_string())].into_py_dict_bound(py)
-                    } else {
-                        pyo3::types::PyDict::new_bound(py)
-                    };
-                    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
-                    Ok(tokenizer)
-                }()
-                .expect("Cannot load the tokenizer");
-                tracing::info!("Loaded a python tokenizer");
-                OwnedTokenizer::Python(pytok)
-            }
-            Self::Rust(tok) => OwnedTokenizer::Rust(tok),
-        }
+pub struct PyTokenizer<'a>(pyo3::Bound<'a, pyo3::PyAny>);
+
+impl<'a> PyTokenizer<'a> {
+    fn from_py(
+        py: Python<'a>,
+        tokenizer_name: String,
+        revision: Option<String>,
+    ) -> PyResult<PyTokenizer<'a>> {
+        let transformers = py.import_bound("transformers")?;
+        let auto = transformers.getattr("AutoTokenizer")?;
+        let from_pretrained = auto.getattr("from_pretrained")?;
+        let args = (tokenizer_name,);
+        let kwargs = if let Some(rev) = &revision {
+            [("revision", rev.to_string())].into_py_dict_bound(py)
+        } else {
+            pyo3::types::PyDict::new_bound(py)
+        };
+        let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+        tracing::info!("Loaded a python tokenizer");
+        Ok(PyTokenizer(tokenizer))
     }
 }
 
-pub enum OwnedTokenizer<'a> {
-    Python(pyo3::Bound<'a, pyo3::PyAny>),
-    Rust(tokenizers::Tokenizer),
+trait TokenizerTrait {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>>;
 }
 
-impl<'a> OwnedTokenizer<'a> {
-    fn encode(
+impl TokenizerTrait for tokenizers::Tokenizer {
+    fn encode_trait(
         &self,
         query: String,
         add_special_tokens: bool,
     ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
-        match self {
-            Self::Python(pytok) => {
-                let py = pytok.py();
-                let kwargs = [
-                    ("text", query.into_py(py)),
-                    ("add_special_tokens", add_special_tokens.into_py(py)),
-                ]
-                .into_py_dict_bound(py);
-                let encode = pytok.getattr("encode")?;
-                let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
-                Ok(Encoding::new(
-                    input_ids,
-                    vec![], // type ids
-                    vec![], // tokens (strings)
-                    vec![], // words
-                    vec![], // offsets
-                    vec![], // special_tokens_mask
-                    vec![], // attention_mask
-                    vec![], // overflowing
-                    std::collections::HashMap::new(), //sequence_ranges
-                ))
-            }
-            Self::Rust(tok) => tok.encode(query, add_special_tokens),
-        }
+        self.encode(query, add_special_tokens)
     }
 }
+
+impl<'a> TokenizerTrait for PyTokenizer<'a> {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
+        let py = self.0.py();
+        let kwargs = [
+            ("text", query.into_py(py)),
+            ("add_special_tokens", add_special_tokens.into_py(py)),
+        ]
+        .into_py_dict_bound(py);
+        let encode = self.0.getattr("encode")?;
+        let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
+        Ok(Encoding::new(
+            input_ids,
+            vec![], // type ids
+            vec![], // tokens (strings)
+            vec![], // words
+            vec![], // offsets
+            vec![], // special_tokens_mask
+            vec![], // attention_mask
+            vec![], // overflowing
+            std::collections::HashMap::new(), //sequence_ranges
+        ))
+    }
+}
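Usage note: the point of TokenizerTrait is that callers become generic over the tokenizer kind instead of matching on an enum that may hold a GIL-bound Python object; prepare_input<T: TokenizerTrait> further down is the real consumer. A small, self-contained sketch of that generic dispatch follows; the count_tokens helper and the tokenizer.json path are hypothetical, only the trait and the tokenizers::Tokenizer impl mirror the code above.

// Sketch only: a hypothetical generic caller over the trait defined above.
trait TokenizerTrait {
    fn encode_trait(
        &self,
        query: String,
        add_special_tokens: bool,
    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>>;
}

impl TokenizerTrait for tokenizers::Tokenizer {
    fn encode_trait(
        &self,
        query: String,
        add_special_tokens: bool,
    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
        self.encode(query, add_special_tokens)
    }
}

// Generic over T, so the Rust tokenizer path is a direct, GIL-free call.
fn count_tokens<T: TokenizerTrait>(tokenizer: &T, text: &str) -> Result<usize, String> {
    let encoding = tokenizer
        .encode_trait(text.to_string(), true)
        .map_err(|err| err.to_string())?;
    Ok(encoding.get_ids().len())
}

fn main() -> Result<(), String> {
    // Assumes a local tokenizer.json; loading from a file is one of several ways
    // to construct a tokenizers::Tokenizer.
    let tokenizer =
        tokenizers::Tokenizer::from_file("tokenizer.json").map_err(|e| e.to_string())?;
    println!("{} tokens", count_tokens(&tokenizer, "Hello world")?);
    Ok(())
}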
@@ -3,8 +3,9 @@ use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+    TokenizerTrait,
 };
-use crate::{OwnedTokenizer, Tokenizer};
+use crate::{PyTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
@@ -434,26 +435,53 @@ fn tokenizer_worker(
     preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
-    pyo3::Python::with_gil(|py| {
-        let tokenizer = tokenizer.into_owned(py);
-        // Loop over requests
-        while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
-            receiver.blocking_recv()
-        {
-            parent_span.in_scope(|| {
-                response_tx
-                    .send(prepare_input(
-                        inputs,
-                        truncate,
-                        add_special_tokens,
-                        &tokenizer,
-                        config.as_ref(),
-                        preprocessor_config.as_ref(),
-                    ))
-                    .unwrap_or(())
-            })
-        }
-    });
+    match tokenizer {
+        Tokenizer::Python {
+            tokenizer_name,
+            revision,
+        } => {
+            pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
+                let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
+                // Loop over requests
+                while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                    receiver.blocking_recv()
+                {
+                    parent_span.in_scope(|| {
+                        response_tx
+                            .send(prepare_input(
+                                inputs,
+                                truncate,
+                                add_special_tokens,
+                                &tokenizer,
+                                config.as_ref(),
+                                preprocessor_config.as_ref(),
+                            ))
+                            .unwrap_or(())
+                    })
+                }
+                Ok(())
+            })
+            .expect("Failure in python tokenizer worker");
+        }
+        Tokenizer::Rust(tokenizer) => {
+            while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                receiver.blocking_recv()
+            {
+                parent_span.in_scope(|| {
+                    response_tx
+                        .send(prepare_input(
+                            inputs,
+                            truncate,
+                            add_special_tokens,
+                            &tokenizer,
+                            config.as_ref(),
+                            preprocessor_config.as_ref(),
+                        ))
+                        .unwrap_or(())
+                })
+            }
+        }
+    }
 }
 
 fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
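For context on the worker shape itself, which this diff keeps: requests arrive over an unbounded channel, the worker drains them with blocking_recv on its own thread, and each request carries a sender for its response. A minimal, self-contained sketch of that pattern follows, using hypothetical request/response types rather than the router's real TokenizerRequest.

// Sketch only: channel-driven blocking worker with a per-request response sender.
use tokio::sync::{mpsc, oneshot};

type Request = (String, oneshot::Sender<usize>);

fn worker(mut receiver: mpsc::UnboundedReceiver<Request>) {
    // blocking_recv is fine here because the worker runs on its own OS thread,
    // outside the async runtime.
    while let Some((text, response_tx)) = receiver.blocking_recv() {
        let token_count = text.split_whitespace().count(); // stand-in for real encoding
        response_tx.send(token_count).unwrap_or(());
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    std::thread::spawn(move || worker(rx));

    let (resp_tx, resp_rx) = oneshot::channel();
    tx.send(("hello tokenizer worker".to_string(), resp_tx)).unwrap();
    println!("tokens: {}", resp_rx.await.unwrap());
}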
@@ -581,11 +609,11 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 }
 
 /// Get input length and optionally truncate it
-fn prepare_input(
+fn prepare_input<T: TokenizerTrait>(
     inputs: String,
     _truncate: Option<usize>,
     add_special_tokens: bool,
-    tokenizer: &OwnedTokenizer,
+    tokenizer: &T,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
@@ -622,7 +650,7 @@ fn prepare_input(
 
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, add_special_tokens)
+        .encode_trait(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     Ok((encoding, input_chunks))