mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Adding the legacy handle.
This commit is contained in:
parent
cd355d08a9
commit
f20ef614bd
@ -46,6 +46,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||
use hf_hub::{Cache, Repo, RepoType};
|
||||
use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::IntoPyDict;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
@ -53,7 +54,7 @@ use std::convert::Infallible;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
use tokio::select;
|
||||
use tokio::signal;
|
||||
@ -1576,6 +1577,66 @@ pub fn schema() -> ApiDoc {
|
||||
ApiDoc
|
||||
}
|
||||
|
||||
fn py_resolve_tokenizer<'a>(
|
||||
py: pyo3::Python<'a>,
|
||||
tokenizer_name: &str,
|
||||
revision: Option<&str>,
|
||||
) -> pyo3::PyResult<()> {
|
||||
let transformers = py.import_bound("transformers")?;
|
||||
let auto = transformers.getattr("AutoTokenizer")?;
|
||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||
let args = (tokenizer_name,);
|
||||
let kwargs = if let Some(rev) = &revision {
|
||||
[("revision", rev.to_string())].into_py_dict_bound(py)
|
||||
} else {
|
||||
pyo3::types::PyDict::new_bound(py)
|
||||
};
|
||||
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
||||
let save = tokenizer.getattr("save_pretrained")?;
|
||||
let args = ("out".to_string(),);
|
||||
save.call1(args)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
|
||||
// XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
|
||||
// and state-spaces/mamba-130m
|
||||
tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct FallbackConfig {
|
||||
base_model_name_or_path: Option<String>,
|
||||
model_type: Option<String>,
|
||||
ssm_config: Option<serde_json::Value>,
|
||||
}
|
||||
config_filename.and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<FallbackConfig, _> = serde_json::from_str(c);
|
||||
if let Ok(config) = config {
|
||||
if config.model_type.is_none() {
|
||||
if let Some(base) = config.base_model_name_or_path {
|
||||
pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
py_resolve_tokenizer(py, &base, Some("main"))
|
||||
})
|
||||
.ok()?;
|
||||
}
|
||||
}
|
||||
if config.ssm_config.is_some() {
|
||||
// XXX Legacy mamba
|
||||
pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"))
|
||||
})
|
||||
.ok()?;
|
||||
}
|
||||
}
|
||||
Some(())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
@ -1739,6 +1800,7 @@ pub async fn run(
|
||||
let tokenizer: Tokenizer = {
|
||||
use pyo3::prelude::*;
|
||||
pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
<<<<<<< HEAD
|
||||
let transformers = py.import_bound("transformers")?;
|
||||
let auto = transformers.getattr("AutoTokenizer")?;
|
||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||
@ -1756,11 +1818,18 @@ pub async fn run(
|
||||
let save = tokenizer.getattr("save_pretrained")?;
|
||||
let args = ("out".to_string(),);
|
||||
save.call1(args)?;
|
||||
=======
|
||||
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref())?;
|
||||
>>>>>>> 744ccdd2 (Adding the legacy handle.)
|
||||
Ok(())
|
||||
})
|
||||
.inspect_err(|err| {
|
||||
tracing::error!("Failed to import python tokenizer {err}");
|
||||
})
|
||||
.or_else(|err| {
|
||||
let out = legacy_tokenizer_handle(config_filename.as_ref());
|
||||
out.ok_or(err)
|
||||
})
|
||||
.expect("We cannot load a tokenizer");
|
||||
let filename = "out/tokenizer.json";
|
||||
if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
|
||||
|
Loading…
Reference in New Issue
Block a user