diff --git a/router/src/server.rs b/router/src/server.rs
index 8608ca2a..748df83e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -46,6 +46,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
 use hf_hub::{Cache, Repo, RepoType};
 use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
+use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use regex::Regex;
 use serde_json::Value;
@@ -53,7 +54,7 @@ use std::convert::Infallible;
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use thiserror::Error;
 use tokio::select;
 use tokio::signal;
@@ -1576,6 +1577,66 @@ pub fn schema() -> ApiDoc {
     ApiDoc
 }
 
+fn py_resolve_tokenizer<'a>(
+    py: pyo3::Python<'a>,
+    tokenizer_name: &str,
+    revision: Option<&str>,
+) -> pyo3::PyResult<()> {
+    let transformers = py.import_bound("transformers")?;
+    let auto = transformers.getattr("AutoTokenizer")?;
+    let from_pretrained = auto.getattr("from_pretrained")?;
+    let args = (tokenizer_name,);
+    let kwargs = if let Some(rev) = &revision {
+        [("revision", rev.to_string())].into_py_dict_bound(py)
+    } else {
+        pyo3::types::PyDict::new_bound(py)
+    };
+    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+    let save = tokenizer.getattr("save_pretrained")?;
+    let args = ("out".to_string(),);
+    save.call1(args)?;
+    Ok(())
+}
+
+fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
+    // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
+    // and state-spaces/mamba-130m
+    tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");
+
+    #[derive(serde::Deserialize)]
+    struct FallbackConfig {
+        base_model_name_or_path: Option<String>,
+        model_type: Option<String>,
+        ssm_config: Option<serde_json::Value>,
+    }
+    config_filename.and_then(|filename| {
+        std::fs::read_to_string(filename)
+            .ok()
+            .as_ref()
+            .and_then(|c| {
+                let config: Result<FallbackConfig, _> = serde_json::from_str(c);
+                if let Ok(config) = config {
+                    if config.model_type.is_none() {
+                        if let Some(base) = config.base_model_name_or_path {
+                            pyo3::Python::with_gil(|py| -> PyResult<()> {
+                                py_resolve_tokenizer(py, &base, Some("main"))
+                            })
+                            .ok()?;
+                        }
+                    }
+                    if config.ssm_config.is_some() {
+                        // XXX Legacy mamba
+                        pyo3::Python::with_gil(|py| -> PyResult<()> {
+                            py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"))
+                        })
+                        .ok()?;
+                    }
+                }
+                Some(())
+            })
+    })
+}
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
@@ -1739,6 +1800,7 @@ pub async fn run(
     let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
         pyo3::Python::with_gil(|py| -> PyResult<()> {
+<<<<<<< HEAD
             let transformers = py.import_bound("transformers")?;
             let auto = transformers.getattr("AutoTokenizer")?;
             let from_pretrained = auto.getattr("from_pretrained")?;
@@ -1756,11 +1818,18 @@ pub async fn run(
             let save = tokenizer.getattr("save_pretrained")?;
             let args = ("out".to_string(),);
             save.call1(args)?;
+=======
+            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref())?;
+>>>>>>> 744ccdd2 (Adding the legacy handle.)
             Ok(())
         })
         .inspect_err(|err| {
            tracing::error!("Failed to import python tokenizer {err}");
         })
+        .or_else(|err| {
+            let out = legacy_tokenizer_handle(config_filename.as_ref());
+            out.ok_or(err)
+        })
         .expect("We cannot load a tokenizer");
         let filename = "out/tokenizer.json";
         if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
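
Note on the fallback wiring: `legacy_tokenizer_handle` returns `Option<()>`, so the
`.or_else` chain converts it back into a `Result` with `Option::ok_or`, which keeps the
original Python error when the legacy path also fails. A minimal, runnable sketch of
the same recovery pattern (`primary` and `legacy_fallback` are hypothetical stand-ins,
not the router's actual functions):

    // Primary resolution path: may fail with an error we want to preserve.
    fn primary() -> Result<(), String> {
        Err("failed to import python tokenizer".to_string())
    }

    // Legacy fallback: Some(()) means it managed to produce a tokenizer.
    fn legacy_fallback() -> Option<()> {
        Some(())
    }

    fn main() {
        // or_else only runs on Err; ok_or revives the original error
        // if the fallback returns None.
        let resolved = primary().or_else(|err| legacy_fallback().ok_or(err));
        assert!(resolved.is_ok());
    }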