No unwrap.

This commit is contained in:
Nicolas Patry 2024-04-11 17:23:08 +00:00
parent a4c86e8678
commit 9ce9f39dea

View File

@ -1,4 +1,5 @@
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use nix::sys::signal::{self, Signal}; use nix::sys::signal::{self, Signal};
use nix::unistd::Pid; use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
@ -19,6 +20,11 @@ use tracing_subscriber::EnvFilter;
mod env_runtime; mod env_runtime;
/// Minimal view of a model's `config.json`, deserialized with serde.
///
/// Only the single field declared below is read from the JSON document.
#[derive(Deserialize)]
struct Config {
/// Value of the `max_position_embeddings` key in `config.json`.
/// NOTE(review): presumably the model's maximum supported sequence length;
/// confirm against the model configuration documentation.
max_position_embeddings: usize,
}
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization { enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model: /// 4 bit quantization. Requires a specific AWQ quantized model:
@ -1260,19 +1266,12 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:?}", args); tracing::info!("{:?}", args);
use hf_hub::{api::sync::Api, Repo, RepoType}; let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
#[derive(Deserialize)]
struct Config {
max_position_embeddings: usize,
}
let config: Config = {
let model_id = args.model_id.clone(); let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf(); let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() { let filename = if !path.exists() {
// Assume it's a hub id // Assume it's a hub id
let api = Api::new().unwrap(); let api = Api::new()?;
let repo = if let Some(ref revision) = args.revision { let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision( api.repo(Repo::with_revision(
model_id, model_id,
@ -1282,14 +1281,14 @@ fn main() -> Result<(), LauncherError> {
} else { } else {
api.model(model_id) api.model(model_id)
}; };
repo.get("config.json").unwrap() repo.get("config.json")?
} else { } else {
path.push("config.json"); path.push("config.json");
path path
}; };
let content = std::fs::read_to_string(filename).unwrap(); let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content).unwrap(); let config: Config = serde_json::from_str(&content)?;
let max_default = 2usize.pow(14); let max_default = 2usize.pow(14);
@ -1300,11 +1299,9 @@ fn main() -> Result<(), LauncherError> {
} else { } else {
config.max_position_embeddings config.max_position_embeddings
}; };
Ok(max_position_embeddings)
Config {
max_position_embeddings,
}
}; };
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
let max_input_tokens = { let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) { match (args.max_input_tokens, args.max_input_length) {
@ -1315,7 +1312,7 @@ fn main() -> Result<(), LauncherError> {
} }
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens, (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
(None, None) => { (None, None) => {
let value = config.max_position_embeddings - 1; let value = max_position_embeddings - 1;
tracing::info!("Default `max_input_tokens` to {value}"); tracing::info!("Default `max_input_tokens` to {value}");
value value
} }
@ -1325,7 +1322,7 @@ fn main() -> Result<(), LauncherError> {
match args.max_total_tokens { match args.max_total_tokens {
Some(max_total_tokens) => max_total_tokens, Some(max_total_tokens) => max_total_tokens,
None => { None => {
let value = config.max_position_embeddings; let value = max_position_embeddings;
tracing::info!("Default `max_total_tokens` to {value}"); tracing::info!("Default `max_total_tokens` to {value}");
value value
} }