From 13fe82264bef6039c2963345cab8f05d844aab1f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 17 Oct 2024 10:58:07 +0200 Subject: [PATCH] Fixing "deadlock" when python prompts for trust_remote_code by always specifiying a value. --- backends/v2/src/main.rs | 4 ++++ backends/v3/src/main.rs | 4 ++++ launcher/src/main.rs | 4 ++++ router/src/server.rs | 12 ++++++++---- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index f53d898e..73330b10 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -44,6 +44,8 @@ struct Args { tokenizer_config_path: Option, #[clap(long, env)] revision: Option, + #[clap(long, env, value_enum)] + trust_remote_code: bool, #[clap(default_value = "2", long, env)] validation_workers: usize, #[clap(long, env)] @@ -101,6 +103,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, validation_workers, api_key, json_output, @@ -184,6 +187,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, hostname, port, cors_allow_origin, diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index b4751bd5..74227386 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -44,6 +44,8 @@ struct Args { tokenizer_config_path: Option, #[clap(long, env)] revision: Option, + #[clap(long, env, value_enum)] + trust_remote_code: bool, #[clap(default_value = "2", long, env)] validation_workers: usize, #[clap(long, env)] @@ -101,6 +103,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, validation_workers, api_key, json_output, @@ -184,6 +187,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, hostname, port, cors_allow_origin, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d9f569fd..98f4621f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1507,6 +1507,10 @@ fn spawn_webserver( router_args.push(revision.to_string()) } + if args.trust_remote_code { + router_args.push("--trust-remote-code".to_string()); + } + if args.json_output { router_args.push("--json-output".to_string()); } diff --git a/router/src/server.rs b/router/src/server.rs index 5e6e6960..403f23ab 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1601,6 +1601,7 @@ pub async fn run( tokenizer_name: String, tokenizer_config_path: Option, revision: Option, + trust_remote_code: bool, hostname: String, port: u16, cors_allow_origin: Option>, @@ -1761,10 +1762,13 @@ pub async fn run( let auto = transformers.getattr("AutoTokenizer")?; let from_pretrained = auto.getattr("from_pretrained")?; let args = (tokenizer_name.to_string(),); - let kwargs = [( - "revision", - revision.clone().unwrap_or_else(|| "main".to_string()), - )] + let kwargs = [ + ( + "revision", + (revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py), + ), + ("trust_remote_code", trust_remote_code.into_py(py)), + ] .into_py_dict_bound(py); let tokenizer = from_pretrained.call(args, Some(&kwargs))?; let save = tokenizer.getattr("save_pretrained")?;