diff --git a/router/src/server.rs b/router/src/server.rs index 010f5fcc..863607b1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1581,15 +1581,20 @@ fn py_resolve_tokenizer( py: pyo3::Python, tokenizer_name: &str, revision: Option<&str>, + trust_remote_code: bool, ) -> pyo3::PyResult<()> { let transformers = py.import_bound("transformers")?; let auto = transformers.getattr("AutoTokenizer")?; let from_pretrained = auto.getattr("from_pretrained")?; let args = (tokenizer_name,); let kwargs = if let Some(rev) = &revision { - [("revision", rev.to_string())].into_py_dict_bound(py) + [ + ("revision", rev.to_string().into_py(py)), + ("trust_remote_code", trust_remote_code.into_py(py)), + ] + .into_py_dict_bound(py) } else { - pyo3::types::PyDict::new_bound(py) + [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py) }; let tokenizer = from_pretrained.call(args, Some(&kwargs))?; let save = tokenizer.getattr("save_pretrained")?; @@ -1619,7 +1624,7 @@ fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> { if config.model_type.is_none() { if let Some(base) = config.base_model_name_or_path { pyo3::Python::with_gil(|py| -> PyResult<()> { - py_resolve_tokenizer(py, &base, Some("main")) + py_resolve_tokenizer(py, &base, Some("main"), false) }) .ok()?; } @@ -1627,7 +1632,7 @@ fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> { if config.ssm_config.is_some() { // XXX Legacy mamba pyo3::Python::with_gil(|py| -> PyResult<()> { - py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main")) + py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"), false) }) .ok()?; } @@ -1800,27 +1805,7 @@ pub async fn run( let tokenizer: Tokenizer = { use pyo3::prelude::*; pyo3::Python::with_gil(|py| -> PyResult<()> { -<<<<<<< HEAD - let transformers = py.import_bound("transformers")?; - let auto = transformers.getattr("AutoTokenizer")?; - let from_pretrained = auto.getattr("from_pretrained")?; - let args = (tokenizer_name.to_string(),); - let kwargs = if let Some(rev) = &revision { - [ - ("revision", rev.to_string().into_py(py)), - ("trust_remote_code", trust_remote_code.into_py(py)), - ] - .into_py_dict_bound(py) - } else { - [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py) - }; - let tokenizer = from_pretrained.call(args, Some(&kwargs))?; - let save = tokenizer.getattr("save_pretrained")?; - let args = ("out".to_string(),); - save.call1(args)?; -======= - py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref())?; ->>>>>>> 744ccdd2 (Adding the legacy handle.) + py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?; Ok(()) }) .inspect_err(|err| {