From c53968dc457a455c92df1dc199f80e7ef50ebb2f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 23 Aug 2024 15:37:54 +0200 Subject: [PATCH] Remove lambda for cleaner function. --- launcher/src/main.rs | 72 +++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index faa84db30..0d6662be7 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -24,6 +24,43 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter}; mod env_runtime; +fn get_config( + model_id: &str, + revision: &Option, +) -> Result> { + let mut path = std::path::Path::new(model_id).to_path_buf(); + let model_id = model_id.to_string(); + let filename = if !path.exists() { + // Assume it's a hub id + + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; + let repo = if let Some(ref revision) = revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? + } else { + path.push("config.json"); + path + }; + + let content = std::fs::read_to_string(filename)?; + let config: RawConfig = serde_json::from_str(&content)?; + + let config: Config = config.into(); + Ok(config) +} + fn resolve_attention(config: &Option, lora_adapters: &Option) -> (String, String) { let mut prefix_caching: Option = std::env::var("USE_PREFIX_CACHING").ok(); let mut attention: Option = std::env::var("ATTENTION").ok(); @@ -1504,40 +1541,7 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:#?}", args); - let get_config = || -> Result> { - let model_id = args.model_id.clone(); - let mut path = std::path::Path::new(&args.model_id).to_path_buf(); - let filename = if !path.exists() { - // Assume it's a hub id - - let api = if let Ok(token) = std::env::var("HF_TOKEN") { - // env variable has precedence over on file token. - ApiBuilder::new().with_token(Some(token)).build()? - } else { - Api::new()? - }; - let repo = if let Some(ref revision) = args.revision { - api.repo(Repo::with_revision( - model_id, - RepoType::Model, - revision.to_string(), - )) - } else { - api.model(model_id) - }; - repo.get("config.json")? - } else { - path.push("config.json"); - path - }; - - let content = std::fs::read_to_string(filename)?; - let config: RawConfig = serde_json::from_str(&content)?; - - let config: Config = config.into(); - Ok(config) - }; - let config: Option = get_config().ok(); + let config: Option = get_config(&args.model_id, &args.revision).ok(); let quantize = config.as_ref().and_then(|c| c.quantize); // Quantization usually means you're even more RAM constrained. let max_default = 4096;