From 9ce9f39deaa96472273f7e8d5b87daeca8fdd532 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 11 Apr 2024 17:23:08 +0000 Subject: [PATCH] No unwrap. --- launcher/src/main.rs | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 580f7476..3f8bd424 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,4 +1,5 @@ use clap::{Parser, ValueEnum}; +use hf_hub::{api::sync::Api, Repo, RepoType}; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; use serde::Deserialize; @@ -19,6 +20,11 @@ use tracing_subscriber::EnvFilter; mod env_runtime; +#[derive(Deserialize)] +struct Config { + max_position_embeddings: usize, +} + #[derive(Clone, Copy, Debug, ValueEnum)] enum Quantization { /// 4 bit quantization. Requires a specific AWQ quantized model: @@ -1260,19 +1266,12 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:?}", args); - use hf_hub::{api::sync::Api, Repo, RepoType}; - - #[derive(Deserialize)] - struct Config { - max_position_embeddings: usize, - } - - let config: Config = { + let get_max_position_embeddings = || -> Result> { let model_id = args.model_id.clone(); let mut path = std::path::Path::new(&args.model_id).to_path_buf(); let filename = if !path.exists() { // Assume it's a hub id - let api = Api::new().unwrap(); + let api = Api::new()?; let repo = if let Some(ref revision) = args.revision { api.repo(Repo::with_revision( model_id, @@ -1282,14 +1281,14 @@ fn main() -> Result<(), LauncherError> { } else { api.model(model_id) }; - repo.get("config.json").unwrap() + repo.get("config.json")? } else { path.push("config.json"); path }; - let content = std::fs::read_to_string(filename).unwrap(); - let config: Config = serde_json::from_str(&content).unwrap(); + let content = std::fs::read_to_string(filename)?; + let config: Config = serde_json::from_str(&content)?; let max_default = 2usize.pow(14); @@ -1300,11 +1299,9 @@ fn main() -> Result<(), LauncherError> { } else { config.max_position_embeddings }; - - Config { - max_position_embeddings, - } + Ok(max_position_embeddings) }; + let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); let max_input_tokens = { match (args.max_input_tokens, args.max_input_length) { @@ -1315,7 +1312,7 @@ fn main() -> Result<(), LauncherError> { } (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens, (None, None) => { - let value = config.max_position_embeddings - 1; + let value = max_position_embeddings - 1; tracing::info!("Default `max_input_tokens` to {value}"); value } @@ -1325,7 +1322,7 @@ fn main() -> Result<(), LauncherError> { match args.max_total_tokens { Some(max_total_tokens) => max_total_tokens, None => { - let value = config.max_position_embeddings; + let value = max_position_embeddings; tracing::info!("Default `max_total_tokens` to {value}"); value }