Fixing the GIL locking.

Nicolas Patry 2024-09-25 01:18:05 +02:00
parent c0151cc14a
commit 9d7a95b24b
2 changed files with 107 additions and 77 deletions
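
The change is easier to follow with the shape of the fix in mind: previously the tokenizer worker entered pyo3::Python::with_gil unconditionally and kept the GIL for its entire request loop, even when only the pure-Rust tokenizer was in use; after this commit the GIL is only taken on the branch that actually drives a Python tokenizer. Below is a minimal standalone sketch of that pattern, assuming the pyo3 0.21-style Bound API used in the diff; the Backend and run_worker names are illustrative, not from the repository.

    // Minimal sketch of the GIL scoping this commit moves to (hypothetical names).
    use pyo3::prelude::*;

    enum Backend {
        Python, // needs the GIL
        Rust,   // plain Rust, no interpreter involved
    }

    fn run_worker(backend: Backend, jobs: Vec<String>) -> PyResult<usize> {
        match backend {
            // Only this branch acquires the GIL, and it holds it for the whole loop,
            // so any Bound<'py, PyAny> objects created here never leave this scope.
            Backend::Python => Python::with_gil(|py| {
                let len = py.import_bound("builtins")?.getattr("len")?;
                let mut total = 0usize;
                for job in jobs {
                    total += len.call1((job,))?.extract::<usize>()?;
                }
                Ok(total)
            }),
            // The pure-Rust branch never touches the interpreter at all.
            Backend::Rust => Ok(jobs.iter().map(|j| j.len()).sum()),
        }
    }

The same scoping appears below in tokenizer_worker: the Tokenizer::Python arm wraps its whole loop in with_gil, while the Tokenizer::Rust arm never calls into pyo3.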


@@ -31,70 +31,72 @@ pub enum Tokenizer {
     Rust(tokenizers::Tokenizer),
 }
-impl Tokenizer {
-    fn into_owned<'a>(self, py: Python<'a>) -> OwnedTokenizer<'a> {
-        match self {
-            Self::Python {
-                tokenizer_name,
-                revision,
-            } => {
-                let pytok = || -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
-                    let transformers = py.import_bound("transformers")?;
-                    let auto = transformers.getattr("AutoTokenizer")?;
-                    let from_pretrained = auto.getattr("from_pretrained")?;
-                    let args = (tokenizer_name.to_string(),);
-                    let kwargs = if let Some(rev) = &revision {
-                        [("revision", rev.to_string())].into_py_dict_bound(py)
-                    } else {
-                        pyo3::types::PyDict::new_bound(py)
-                    };
-                    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
-                    Ok(tokenizer)
-                }()
-                .expect("Cannot load the tokenizer");
-                tracing::info!("Loaded a python tokenizer");
-                OwnedTokenizer::Python(pytok)
-            }
-            Self::Rust(tok) => OwnedTokenizer::Rust(tok),
-        }
+pub struct PyTokenizer<'a>(pyo3::Bound<'a, pyo3::PyAny>);
+impl<'a> PyTokenizer<'a> {
+    fn from_py(
+        py: Python<'a>,
+        tokenizer_name: String,
+        revision: Option<String>,
+    ) -> PyResult<PyTokenizer<'a>> {
+        let transformers = py.import_bound("transformers")?;
+        let auto = transformers.getattr("AutoTokenizer")?;
+        let from_pretrained = auto.getattr("from_pretrained")?;
+        let args = (tokenizer_name,);
+        let kwargs = if let Some(rev) = &revision {
+            [("revision", rev.to_string())].into_py_dict_bound(py)
+        } else {
+            pyo3::types::PyDict::new_bound(py)
+        };
+        let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+        tracing::info!("Loaded a python tokenizer");
+        Ok(PyTokenizer(tokenizer))
     }
 }
-pub enum OwnedTokenizer<'a> {
-    Python(pyo3::Bound<'a, pyo3::PyAny>),
-    Rust(tokenizers::Tokenizer),
+trait TokenizerTrait {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>>;
 }
-impl<'a> OwnedTokenizer<'a> {
-    fn encode(
+impl TokenizerTrait for tokenizers::Tokenizer {
+    fn encode_trait(
         &self,
         query: String,
         add_special_tokens: bool,
     ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
-        match self {
-            Self::Python(pytok) => {
-                let py = pytok.py();
-                let kwargs = [
-                    ("text", query.into_py(py)),
-                    ("add_special_tokens", add_special_tokens.into_py(py)),
-                ]
-                .into_py_dict_bound(py);
-                let encode = pytok.getattr("encode")?;
-                let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
-                Ok(Encoding::new(
-                    input_ids,
-                    vec![], // type ids
-                    vec![], // tokens (strings)
-                    vec![], // words
-                    vec![], // offsets
-                    vec![], // special_tokens_mask
-                    vec![], // attention_mask
-                    vec![], // overflowing
-                    std::collections::HashMap::new(), //sequence_ranges
-                ))
-            }
-            Self::Rust(tok) => tok.encode(query, add_special_tokens),
-        }
+        self.encode(query, add_special_tokens)
     }
 }
+impl<'a> TokenizerTrait for PyTokenizer<'a> {
+    fn encode_trait(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
+        let py = self.0.py();
+        let kwargs = [
+            ("text", query.into_py(py)),
+            ("add_special_tokens", add_special_tokens.into_py(py)),
+        ]
+        .into_py_dict_bound(py);
+        let encode = self.0.getattr("encode")?;
+        let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
+        Ok(Encoding::new(
+            input_ids,
+            vec![], // type ids
+            vec![], // tokens (strings)
+            vec![], // words
+            vec![], // offsets
+            vec![], // special_tokens_mask
+            vec![], // attention_mask
+            vec![], // overflowing
+            std::collections::HashMap::new(), //sequence_ranges
+        ))
+    }
+}
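
For context on how the new trait is consumed, here is a minimal hypothetical caller in the same spirit as the generic prepare_input<T: TokenizerTrait> in the second file below: it only needs encode_trait, so it works unchanged for both the Rust tokenizer and the lifetime-bound PyTokenizer<'a>, and never has to name a pyo3 type. The count_tokens helper is illustrative and does not exist in the repository.

    // Hypothetical helper generic over the trait introduced above.
    fn count_tokens<T: TokenizerTrait>(
        tokenizer: &T,
        query: String,
    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
        // Both implementations return a tokenizers::Encoding, so downstream code
        // stays agnostic about whether the GIL is involved.
        let encoding = tokenizer.encode_trait(query, true)?;
        Ok(encoding.get_ids().len())
    }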


@@ -3,8 +3,9 @@ use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+    TokenizerTrait,
 };
-use crate::{OwnedTokenizer, Tokenizer};
+use crate::{PyTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
@@ -434,26 +435,53 @@ fn tokenizer_worker(
     preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
-    pyo3::Python::with_gil(|py| {
-        let tokenizer = tokenizer.into_owned(py);
-        // Loop over requests
-        while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
-            receiver.blocking_recv()
-        {
-            parent_span.in_scope(|| {
-                response_tx
-                    .send(prepare_input(
-                        inputs,
-                        truncate,
-                        add_special_tokens,
-                        &tokenizer,
-                        config.as_ref(),
-                        preprocessor_config.as_ref(),
-                    ))
-                    .unwrap_or(())
-            })
-        }
-    });
+    match tokenizer {
+        Tokenizer::Python {
+            tokenizer_name,
+            revision,
+        } => {
+            pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
+                let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
+                // Loop over requests
+                while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                    receiver.blocking_recv()
+                {
+                    parent_span.in_scope(|| {
+                        response_tx
+                            .send(prepare_input(
+                                inputs,
+                                truncate,
+                                add_special_tokens,
+                                &tokenizer,
+                                config.as_ref(),
+                                preprocessor_config.as_ref(),
+                            ))
+                            .unwrap_or(())
+                    })
+                }
+                Ok(())
+            })
+            .expect("Failure in python tokenizer worker");
+        }
+        Tokenizer::Rust(tokenizer) => {
+            while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+                receiver.blocking_recv()
+            {
+                parent_span.in_scope(|| {
+                    response_tx
+                        .send(prepare_input(
+                            inputs,
+                            truncate,
+                            add_special_tokens,
+                            &tokenizer,
+                            config.as_ref(),
+                            preprocessor_config.as_ref(),
+                        ))
+                        .unwrap_or(())
+                })
+            }
+        }
+    }
 }
 fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
@@ -581,11 +609,11 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 }
 /// Get input length and optionally truncate it
-fn prepare_input(
+fn prepare_input<T: TokenizerTrait>(
     inputs: String,
     _truncate: Option<usize>,
     add_special_tokens: bool,
-    tokenizer: &OwnedTokenizer,
+    tokenizer: &T,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
@@ -622,7 +650,7 @@ fn prepare_input(
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, add_special_tokens)
+        .encode_trait(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
     Ok((encoding, input_chunks))