mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixing bad rebase.
This commit is contained in:
parent
bbbd9a6dd2
commit
123ff3a83e
@ -1581,15 +1581,20 @@ fn py_resolve_tokenizer(
|
||||
py: pyo3::Python,
|
||||
tokenizer_name: &str,
|
||||
revision: Option<&str>,
|
||||
trust_remote_code: bool,
|
||||
) -> pyo3::PyResult<()> {
|
||||
let transformers = py.import_bound("transformers")?;
|
||||
let auto = transformers.getattr("AutoTokenizer")?;
|
||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||
let args = (tokenizer_name,);
|
||||
let kwargs = if let Some(rev) = &revision {
|
||||
[("revision", rev.to_string())].into_py_dict_bound(py)
|
||||
[
|
||||
("revision", rev.to_string().into_py(py)),
|
||||
("trust_remote_code", trust_remote_code.into_py(py)),
|
||||
]
|
||||
.into_py_dict_bound(py)
|
||||
} else {
|
||||
pyo3::types::PyDict::new_bound(py)
|
||||
[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
|
||||
};
|
||||
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
||||
let save = tokenizer.getattr("save_pretrained")?;
|
||||
@ -1619,7 +1624,7 @@ fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
|
||||
if config.model_type.is_none() {
|
||||
if let Some(base) = config.base_model_name_or_path {
|
||||
pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
py_resolve_tokenizer(py, &base, Some("main"))
|
||||
py_resolve_tokenizer(py, &base, Some("main"), false)
|
||||
})
|
||||
.ok()?;
|
||||
}
|
||||
@ -1627,7 +1632,7 @@ fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
|
||||
if config.ssm_config.is_some() {
|
||||
// XXX Legacy mamba
|
||||
pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"))
|
||||
py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"), false)
|
||||
})
|
||||
.ok()?;
|
||||
}
|
||||
@ -1800,27 +1805,7 @@ pub async fn run(
|
||||
let tokenizer: Tokenizer = {
|
||||
use pyo3::prelude::*;
|
||||
pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
<<<<<<< HEAD
|
||||
let transformers = py.import_bound("transformers")?;
|
||||
let auto = transformers.getattr("AutoTokenizer")?;
|
||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||
let args = (tokenizer_name.to_string(),);
|
||||
let kwargs = if let Some(rev) = &revision {
|
||||
[
|
||||
("revision", rev.to_string().into_py(py)),
|
||||
("trust_remote_code", trust_remote_code.into_py(py)),
|
||||
]
|
||||
.into_py_dict_bound(py)
|
||||
} else {
|
||||
[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
|
||||
};
|
||||
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
||||
let save = tokenizer.getattr("save_pretrained")?;
|
||||
let args = ("out".to_string(),);
|
||||
save.call1(args)?;
|
||||
=======
|
||||
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref())?;
|
||||
>>>>>>> 744ccdd2 (Adding the legacy handle.)
|
||||
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;
|
||||
Ok(())
|
||||
})
|
||||
.inspect_err(|err| {
|
||||
|
Loading…
Reference in New Issue
Block a user