Fixing the GIL locking.

Nicolas Patry 2024-09-25 01:18:05 +02:00
parent c0151cc14a
commit 9d7a95b24b
2 changed files with 107 additions and 77 deletions
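
The diff below splits the tokenizer worker so that pyo3's GIL is only acquired on the Python-tokenizer path; the pure-Rust tokenizer path no longer runs inside Python::with_gil. A minimal standalone sketch of the two shapes (not the project's code; handle_request is a hypothetical stand-in for the per-request work):

// Minimal sketch of the locking shapes, assuming the pyo3 crate is available.
// `handle_request` is a hypothetical stand-in for the real per-request work.
use pyo3::prelude::*;

fn handle_request(req: &str) -> usize {
    req.len()
}

// Old shape: the GIL wraps the whole request loop, even when no Python
// object is ever touched, so the worker holds the GIL for its lifetime.
fn worker_old(requests: &[String]) {
    Python::with_gil(|_py| {
        for req in requests {
            let _ = handle_request(req);
        }
    });
}

// New shape: only the Python branch acquires the GIL; the Rust branch
// never touches it.
fn worker_new(requests: &[String], uses_python: bool) {
    if uses_python {
        Python::with_gil(|_py| {
            for req in requests {
                let _ = handle_request(req);
            }
        });
    } else {
        for req in requests {
            let _ = handle_request(req);
        }
    }
}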


@@ -31,70 +31,72 @@ pub enum Tokenizer {
     Rust(tokenizers::Tokenizer),
 }
 
-impl Tokenizer {
-    fn into_owned<'a>(self, py: Python<'a>) -> OwnedTokenizer<'a> {
-        match self {
-            Self::Python {
-                tokenizer_name,
-                revision,
-            } => {
-                let pytok = || -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
-                    let transformers = py.import_bound("transformers")?;
-                    let auto = transformers.getattr("AutoTokenizer")?;
-                    let from_pretrained = auto.getattr("from_pretrained")?;
-                    let args = (tokenizer_name.to_string(),);
-                    let kwargs = if let Some(rev) = &revision {
-                        [("revision", rev.to_string())].into_py_dict_bound(py)
-                    } else {
-                        pyo3::types::PyDict::new_bound(py)
-                    };
-                    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
-                    Ok(tokenizer)
-                }()
-                .expect("Cannot load the tokenizer");
-                tracing::info!("Loaded a python tokenizer");
-                OwnedTokenizer::Python(pytok)
-            }
-            Self::Rust(tok) => OwnedTokenizer::Rust(tok),
-        }
-    }
-}
+pub struct PyTokenizer<'a>(pyo3::Bound<'a, pyo3::PyAny>);
+
+impl<'a> PyTokenizer<'a> {
+    fn from_py(
+        py: Python<'a>,
+        tokenizer_name: String,
+        revision: Option<String>,
+    ) -> PyResult<PyTokenizer<'a>> {
+        let transformers = py.import_bound("transformers")?;
+        let auto = transformers.getattr("AutoTokenizer")?;
+        let from_pretrained = auto.getattr("from_pretrained")?;
+        let args = (tokenizer_name,);
+        let kwargs = if let Some(rev) = &revision {
+            [("revision", rev.to_string())].into_py_dict_bound(py)
+        } else {
+            pyo3::types::PyDict::new_bound(py)
+        };
+        let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+        tracing::info!("Loaded a python tokenizer");
+        Ok(PyTokenizer(tokenizer))
+    }
+}
 
-pub enum OwnedTokenizer<'a> {
-    Python(pyo3::Bound<'a, pyo3::PyAny>),
-    Rust(tokenizers::Tokenizer),
+trait TokenizerTrait {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>>;
 }
 
-impl<'a> OwnedTokenizer<'a> {
-    fn encode(
+impl TokenizerTrait for tokenizers::Tokenizer {
+    fn encode_trait(
         &self,
         query: String,
         add_special_tokens: bool,
     ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
-        match self {
-            Self::Python(pytok) => {
-                let py = pytok.py();
-                let kwargs = [
-                    ("text", query.into_py(py)),
-                    ("add_special_tokens", add_special_tokens.into_py(py)),
-                ]
-                .into_py_dict_bound(py);
-                let encode = pytok.getattr("encode")?;
-                let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
-                Ok(Encoding::new(
-                    input_ids,
-                    vec![], // type ids
-                    vec![], // tokens (strings)
-                    vec![], // words
-                    vec![], // offsets
-                    vec![], // special_tokens_mask
-                    vec![], // attention_mask
-                    vec![], // overflowing
-                    std::collections::HashMap::new(), //sequence_ranges
-                ))
-            }
-            Self::Rust(tok) => tok.encode(query, add_special_tokens),
-        }
+        self.encode(query, add_special_tokens)
+    }
+}
+
+impl<'a> TokenizerTrait for PyTokenizer<'a> {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
+        let py = self.0.py();
+        let kwargs = [
+            ("text", query.into_py(py)),
+            ("add_special_tokens", add_special_tokens.into_py(py)),
+        ]
+        .into_py_dict_bound(py);
+        let encode = self.0.getattr("encode")?;
+        let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
+        Ok(Encoding::new(
+            input_ids,
+            vec![], // type ids
+            vec![], // tokens (strings)
+            vec![], // words
+            vec![], // offsets
+            vec![], // special_tokens_mask
+            vec![], // attention_mask
+            vec![], // overflowing
+            std::collections::HashMap::new(), //sequence_ranges
+        ))
     }
 }
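
For orientation, the enum-based dispatch (OwnedTokenizer) above is replaced by trait-based static dispatch. A reduced, self-contained sketch of that pattern; the toy names (Encode, RustToy, PythonToy, toy_prepare) are illustrative only and not part of the commit:

// Simplified sketch of the trait-based dispatch introduced above; the real
// trait is `TokenizerTrait`, these toy types stand in for the two backends.
trait Encode {
    fn encode_ids(&self, query: &str) -> Vec<u32>;
}

struct RustToy;
struct PythonToy;

impl Encode for RustToy {
    fn encode_ids(&self, query: &str) -> Vec<u32> {
        query.bytes().map(u32::from).collect()
    }
}

impl Encode for PythonToy {
    fn encode_ids(&self, query: &str) -> Vec<u32> {
        query.chars().map(|c| c as u32).collect()
    }
}

// Generic over the trait, mirroring `prepare_input<T: TokenizerTrait>` in the
// second file: the caller picks the backend, the body stays identical.
fn toy_prepare<T: Encode>(tokenizer: &T, query: &str) -> usize {
    tokenizer.encode_ids(query).len()
}

fn main() {
    assert_eq!(toy_prepare(&RustToy, "abc"), 3);
    assert_eq!(toy_prepare(&PythonToy, "abc"), 3);
}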


@@ -3,8 +3,9 @@ use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+    TokenizerTrait,
 };
-use crate::{OwnedTokenizer, Tokenizer};
+use crate::{PyTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
@@ -434,26 +435,53 @@ fn tokenizer_worker(
     preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
-    pyo3::Python::with_gil(|py| {
-        let tokenizer = tokenizer.into_owned(py);
-        // Loop over requests
-        while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
-            receiver.blocking_recv()
-        {
-            parent_span.in_scope(|| {
-                response_tx
-                    .send(prepare_input(
-                        inputs,
-                        truncate,
-                        add_special_tokens,
-                        &tokenizer,
-                        config.as_ref(),
-                        preprocessor_config.as_ref(),
-                    ))
-                    .unwrap_or(())
-            })
-        }
-    });
+    match tokenizer {
+        Tokenizer::Python {
+            tokenizer_name,
+            revision,
+        } => {
+            pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
+                let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
+                // Loop over requests
+                while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                    receiver.blocking_recv()
+                {
+                    parent_span.in_scope(|| {
+                        response_tx
+                            .send(prepare_input(
+                                inputs,
+                                truncate,
+                                add_special_tokens,
+                                &tokenizer,
+                                config.as_ref(),
+                                preprocessor_config.as_ref(),
+                            ))
+                            .unwrap_or(())
+                    })
+                }
+                Ok(())
+            })
+            .expect("Failure in python tokenizer worker");
+        }
+        Tokenizer::Rust(tokenizer) => {
+            while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                receiver.blocking_recv()
+            {
+                parent_span.in_scope(|| {
+                    response_tx
+                        .send(prepare_input(
+                            inputs,
+                            truncate,
+                            add_special_tokens,
+                            &tokenizer,
+                            config.as_ref(),
+                            preprocessor_config.as_ref(),
+                        ))
+                        .unwrap_or(())
+                })
+            }
+        }
+    }
 }
 
 fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
@@ -581,11 +609,11 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 }
 
 /// Get input length and optionally truncate it
-fn prepare_input(
+fn prepare_input<T: TokenizerTrait>(
     inputs: String,
     _truncate: Option<usize>,
     add_special_tokens: bool,
-    tokenizer: &OwnedTokenizer,
+    tokenizer: &T,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
@@ -622,7 +650,7 @@ fn prepare_input(
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, add_special_tokens)
+        .encode_trait(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     Ok((encoding, input_chunks))
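
The worker loop above drains a tokio unbounded channel from a dedicated blocking thread and answers each request over a response channel. A reduced, self-contained sketch of that pattern; the Request type, token counting, and the use of a oneshot reply channel are illustrative assumptions, not the commit's code:

// Sketch of the worker pattern visible in `tokenizer_worker`: requests arrive
// on an unbounded tokio channel, a dedicated thread drains them with
// `blocking_recv`, and each reply goes back over a oneshot channel.
use tokio::sync::{mpsc, oneshot};

type Request = (String, oneshot::Sender<usize>);

fn spawn_worker(mut receiver: mpsc::UnboundedReceiver<Request>) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        // Loop until every sender is dropped, mirroring the request loop above.
        while let Some((input, response_tx)) = receiver.blocking_recv() {
            let token_count = input.split_whitespace().count();
            let _ = response_tx.send(token_count);
        }
    })
}

fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let worker = spawn_worker(rx);

    let (resp_tx, resp_rx) = oneshot::channel();
    tx.send(("hello tokenizer world".to_string(), resp_tx)).unwrap();
    assert_eq!(resp_rx.blocking_recv().unwrap(), 3);

    drop(tx); // close the channel so the worker thread exits
    worker.join().unwrap();
}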