From 97f7a22f0b0f57edc840beaf152e7fd102ed8311 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Thu, 7 Nov 2024 21:43:38 +0800 Subject: [PATCH] add trust_remote_code in tokenizer to fix baichuan issue (#2725) Signed-off-by: Wang, Yi A --- router/src/lib.rs | 10 ++++++++-- router/src/server.rs | 1 + router/src/validation.rs | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index a5613f89..d9cacb91 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -27,6 +27,7 @@ pub enum Tokenizer { Python { tokenizer_name: String, revision: Option, + trust_remote_code: bool, }, Rust(tokenizers::Tokenizer), } @@ -38,15 +39,20 @@ impl<'a> PyTokenizer<'a> { py: Python<'a>, tokenizer_name: String, revision: Option, + trust_remote_code: bool, ) -> PyResult> { let transformers = py.import_bound("transformers")?; let auto = transformers.getattr("AutoTokenizer")?; let from_pretrained = auto.getattr("from_pretrained")?; let args = (tokenizer_name,); let kwargs = if let Some(rev) = &revision { - [("revision", rev.to_string())].into_py_dict_bound(py) + [ + ("revision", rev.to_string().into_py(py)), + ("trust_remote_code", trust_remote_code.into_py(py)), + ] + .into_py_dict_bound(py) } else { - pyo3::types::PyDict::new_bound(py) + [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py) }; let tokenizer = from_pretrained.call(args, Some(&kwargs))?; tracing::info!("Loaded a python tokenizer"); diff --git a/router/src/server.rs b/router/src/server.rs index 7d8d518c..2058bce3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1829,6 +1829,7 @@ pub async fn run( Tokenizer::Python { tokenizer_name: tokenizer_name.clone(), revision: revision.clone(), + trust_remote_code, } } }; diff --git a/router/src/validation.rs b/router/src/validation.rs index 5b2a153c..3cd85a6e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -439,9 +439,11 @@ fn tokenizer_worker( Tokenizer::Python { tokenizer_name, revision, + trust_remote_code, } => { pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> { - let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?; + let tokenizer = + PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?; // Loop over requests while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = receiver.blocking_recv()