No unwrap.

This commit is contained in:
Nicolas Patry 2024-04-11 17:23:08 +00:00
parent a4c86e8678
commit 9ce9f39dea

View File

@ -1,4 +1,5 @@
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
@ -19,6 +20,11 @@ use tracing_subscriber::EnvFilter;
mod env_runtime;
#[derive(Deserialize)]
struct Config {
max_position_embeddings: usize,
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model:
@ -1260,19 +1266,12 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:?}", args);
use hf_hub::{api::sync::Api, Repo, RepoType};
#[derive(Deserialize)]
struct Config {
max_position_embeddings: usize,
}
let config: Config = {
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = Api::new().unwrap();
let api = Api::new()?;
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
@ -1282,14 +1281,14 @@ fn main() -> Result<(), LauncherError> {
} else {
api.model(model_id)
};
repo.get("config.json").unwrap()
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename).unwrap();
let config: Config = serde_json::from_str(&content).unwrap();
let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content)?;
let max_default = 2usize.pow(14);
@ -1300,11 +1299,9 @@ fn main() -> Result<(), LauncherError> {
} else {
config.max_position_embeddings
};
Config {
max_position_embeddings,
}
Ok(max_position_embeddings)
};
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
@ -1315,7 +1312,7 @@ fn main() -> Result<(), LauncherError> {
}
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
(None, None) => {
let value = config.max_position_embeddings - 1;
let value = max_position_embeddings - 1;
tracing::info!("Default `max_input_tokens` to {value}");
value
}
@ -1325,7 +1322,7 @@ fn main() -> Result<(), LauncherError> {
match args.max_total_tokens {
Some(max_total_tokens) => max_total_tokens,
None => {
let value = config.max_position_embeddings;
let value = max_position_embeddings;
tracing::info!("Default `max_total_tokens` to {value}");
value
}