Fixing bad rebase.

This commit is contained in:
Nicolas Patry 2024-10-25 09:58:46 +02:00
parent bbbd9a6dd2
commit 123ff3a83e
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@@ -1581,15 +1581,20 @@ fn py_resolve_tokenizer(
 py: pyo3::Python,
 tokenizer_name: &str,
 revision: Option<&str>,
+trust_remote_code: bool,
 ) -> pyo3::PyResult<()> {
 let transformers = py.import_bound("transformers")?;
 let auto = transformers.getattr("AutoTokenizer")?;
 let from_pretrained = auto.getattr("from_pretrained")?;
 let args = (tokenizer_name,);
 let kwargs = if let Some(rev) = &revision {
-[("revision", rev.to_string())].into_py_dict_bound(py)
+[
+("revision", rev.to_string().into_py(py)),
+("trust_remote_code", trust_remote_code.into_py(py)),
+]
+.into_py_dict_bound(py)
 } else {
-pyo3::types::PyDict::new_bound(py)
+[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
 };
 let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
 let save = tokenizer.getattr("save_pretrained")?;
@@ -1619,7 +1624,7 @@ fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
 if config.model_type.is_none() {
 if let Some(base) = config.base_model_name_or_path {
 pyo3::Python::with_gil(|py| -> PyResult<()> {
-py_resolve_tokenizer(py, &base, Some("main"))
+py_resolve_tokenizer(py, &base, Some("main"), false)
 })
 .ok()?;
 }
@@ -1627,7 +1632,7 @@ fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
 if config.ssm_config.is_some() {
 // XXX Legacy mamba
 pyo3::Python::with_gil(|py| -> PyResult<()> {
-py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"))
+py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"), false)
 })
 .ok()?;
 }
@@ -1800,27 +1805,7 @@ pub async fn run(
 let tokenizer: Tokenizer = {
 use pyo3::prelude::*;
 pyo3::Python::with_gil(|py| -> PyResult<()> {
-<<<<<<< HEAD
-let transformers = py.import_bound("transformers")?;
-let auto = transformers.getattr("AutoTokenizer")?;
-let from_pretrained = auto.getattr("from_pretrained")?;
-let args = (tokenizer_name.to_string(),);
-let kwargs = if let Some(rev) = &revision {
-[
-("revision", rev.to_string().into_py(py)),
-("trust_remote_code", trust_remote_code.into_py(py)),
-]
-.into_py_dict_bound(py)
-} else {
-[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
-};
-let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
-let save = tokenizer.getattr("save_pretrained")?;
-let args = ("out".to_string(),);
-save.call1(args)?;
-=======
-py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref())?;
->>>>>>> 744ccdd2 (Adding the legacy handle.)
+py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;
 Ok(())
 })
 .inspect_err(|err| {