Remove lambda for cleaner function.

This commit is contained in:
Nicolas Patry 2024-08-23 15:37:54 +02:00
parent 32f6416358
commit c53968dc45
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@ -24,6 +24,43 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime; mod env_runtime;
fn get_config(
model_id: &str,
revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
let mut path = std::path::Path::new(model_id).to_path_buf();
let model_id = model_id.to_string();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
Ok(config)
}
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) { fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok(); let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok(); let mut attention: Option<String> = std::env::var("ATTENTION").ok();
@ -1504,40 +1541,7 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:#?}", args); tracing::info!("{:#?}", args);
let get_config = || -> Result<Config, Box<dyn std::error::Error>> { let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
Ok(config)
};
let config: Option<Config> = get_config().ok();
let quantize = config.as_ref().and_then(|c| c.quantize); let quantize = config.as_ref().and_then(|c| c.quantize);
// Quantization usually means you're even more RAM constrained. // Quantization usually means you're even more RAM constrained.
let max_default = 4096; let max_default = 4096;