Adding the legacy handle.
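
Factors the `transformers`-based tokenizer resolution into a `py_resolve_tokenizer` helper and adds `legacy_tokenizer_handle`, a config-driven fallback for older checkpoints (medusa-style configs that name a `base_model_name_or_path` but no `model_type`, and mamba checkpoints carrying an `ssm_config`) that would otherwise fail tokenizer loading.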

Nicolas Patry 2024-09-25 14:37:36 +02:00
parent cd355d08a9
commit f20ef614bd


@@ -46,6 +46,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use regex::Regex;
use serde_json::Value;
@@ -53,7 +54,7 @@ use std::convert::Infallible;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::path::{Path, PathBuf};
use thiserror::Error;
use tokio::select;
use tokio::signal;
@@ -1576,6 +1577,66 @@ pub fn schema() -> ApiDoc {
    ApiDoc
}
fn py_resolve_tokenizer<'a>(
    py: pyo3::Python<'a>,
    tokenizer_name: &str,
    revision: Option<&str>,
) -> pyo3::PyResult<()> {
    // Call `transformers.AutoTokenizer.from_pretrained(tokenizer_name, revision=...)`
    // through pyo3, forwarding the revision kwarg only when one was given.
    let transformers = py.import_bound("transformers")?;
    let auto = transformers.getattr("AutoTokenizer")?;
    let from_pretrained = auto.getattr("from_pretrained")?;
    let args = (tokenizer_name,);
    let kwargs = if let Some(rev) = &revision {
        [("revision", rev.to_string())].into_py_dict_bound(py)
    } else {
        pyo3::types::PyDict::new_bound(py)
    };
    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
    // Serialize the resolved tokenizer into `out/` so the Rust side can
    // re-read it from disk afterwards.
    let save = tokenizer.getattr("save_pretrained")?;
    let args = ("out".to_string(),);
    save.call1(args)?;
    Ok(())
}
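
// A minimal sketch of driving the helper above on its own. It assumes pyo3's
// `auto-initialize` feature (or an already-initialized interpreter) and a
// Python environment with `transformers` installed; the "gpt2" id is illustrative.
#[allow(dead_code)]
fn resolve_tokenizer_standalone() -> pyo3::PyResult<()> {
    pyo3::Python::with_gil(|py| py_resolve_tokenizer(py, "gpt2", Some("main")))
}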
fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
    // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
    // and state-spaces/mamba-130m
    tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");

    #[derive(serde::Deserialize)]
    struct FallbackConfig {
        base_model_name_or_path: Option<String>,
        model_type: Option<String>,
        ssm_config: Option<serde_json::Value>,
    }
    config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<FallbackConfig, _> = serde_json::from_str(c);
                if let Ok(config) = config {
                    if config.model_type.is_none() {
                        if let Some(base) = config.base_model_name_or_path {
                            // Medusa-style config: resolve the tokenizer of the base model.
                            pyo3::Python::with_gil(|py| -> PyResult<()> {
                                py_resolve_tokenizer(py, &base, Some("main"))
                            })
                            .ok()?;
                        }
                    }
                    if config.ssm_config.is_some() {
                        // XXX Legacy mamba checkpoints reuse the gpt-neox-20b tokenizer.
                        pyo3::Python::with_gil(|py| -> PyResult<()> {
                            py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"))
                        })
                        .ok()?;
                    }
                }
                Some(())
            })
    })
}
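
// A minimal sketch of the two config shapes the fallback above reacts to.
// `LocalFallbackConfig` mirrors the private struct inside `legacy_tokenizer_handle`,
// and both JSON bodies are illustrative, not copied from the real repositories.
#[cfg(test)]
mod legacy_fallback_shape_tests {
    #[derive(serde::Deserialize)]
    struct LocalFallbackConfig {
        base_model_name_or_path: Option<String>,
        model_type: Option<String>,
        ssm_config: Option<serde_json::Value>,
    }

    #[test]
    fn medusa_style_config_names_a_base_model_without_model_type() {
        let config: LocalFallbackConfig =
            serde_json::from_str(r#"{"base_model_name_or_path": "some/base-model"}"#).unwrap();
        assert!(config.model_type.is_none());
        assert!(config.base_model_name_or_path.is_some());
    }

    #[test]
    fn mamba_style_config_carries_an_ssm_config_blob() {
        let config: LocalFallbackConfig =
            serde_json::from_str(r#"{"ssm_config": {"d_model": 768}}"#).unwrap();
        assert!(config.ssm_config.is_some());
    }
}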
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
@@ -1739,6 +1800,7 @@ pub async fn run(
    let tokenizer: Tokenizer = {
        use pyo3::prelude::*;
        pyo3::Python::with_gil(|py| -> PyResult<()> {
            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref())?;
            Ok(())
        })
        .inspect_err(|err| {
            tracing::error!("Failed to import python tokenizer {err}");
        })
        .or_else(|err| {
            // Python resolution failed: try the config-driven legacy fallback,
            // surfacing the original error if that fails too.
            let out = legacy_tokenizer_handle(config_filename.as_ref());
            out.ok_or(err)
        })
        .expect("We cannot load a tokenizer");
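        // Whichever path succeeded above, a `save_pretrained("out")` call has
        // populated `out/`, so the native `tokenizers` crate can re-read the
        // converted tokenizer from disk below.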
        let filename = "out/tokenizer.json";
        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {