mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
No unwrap.
This commit is contained in:
parent
a4c86e8678
commit
9ce9f39dea
@ -1,4 +1,5 @@
|
|||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use nix::sys::signal::{self, Signal};
|
use nix::sys::signal::{self, Signal};
|
||||||
use nix::unistd::Pid;
|
use nix::unistd::Pid;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
@ -19,6 +20,11 @@ use tracing_subscriber::EnvFilter;
|
|||||||
|
|
||||||
mod env_runtime;
|
mod env_runtime;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Config {
|
||||||
|
max_position_embeddings: usize,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum Quantization {
|
enum Quantization {
|
||||||
/// 4 bit quantization. Requires a specific AWQ quantized model:
|
/// 4 bit quantization. Requires a specific AWQ quantized model:
|
||||||
@ -1260,19 +1266,12 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
|
|
||||||
tracing::info!("{:?}", args);
|
tracing::info!("{:?}", args);
|
||||||
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct Config {
|
|
||||||
max_position_embeddings: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
let config: Config = {
|
|
||||||
let model_id = args.model_id.clone();
|
let model_id = args.model_id.clone();
|
||||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
||||||
let filename = if !path.exists() {
|
let filename = if !path.exists() {
|
||||||
// Assume it's a hub id
|
// Assume it's a hub id
|
||||||
let api = Api::new().unwrap();
|
let api = Api::new()?;
|
||||||
let repo = if let Some(ref revision) = args.revision {
|
let repo = if let Some(ref revision) = args.revision {
|
||||||
api.repo(Repo::with_revision(
|
api.repo(Repo::with_revision(
|
||||||
model_id,
|
model_id,
|
||||||
@ -1282,14 +1281,14 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
} else {
|
} else {
|
||||||
api.model(model_id)
|
api.model(model_id)
|
||||||
};
|
};
|
||||||
repo.get("config.json").unwrap()
|
repo.get("config.json")?
|
||||||
} else {
|
} else {
|
||||||
path.push("config.json");
|
path.push("config.json");
|
||||||
path
|
path
|
||||||
};
|
};
|
||||||
|
|
||||||
let content = std::fs::read_to_string(filename).unwrap();
|
let content = std::fs::read_to_string(filename)?;
|
||||||
let config: Config = serde_json::from_str(&content).unwrap();
|
let config: Config = serde_json::from_str(&content)?;
|
||||||
|
|
||||||
let max_default = 2usize.pow(14);
|
let max_default = 2usize.pow(14);
|
||||||
|
|
||||||
@ -1300,11 +1299,9 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
} else {
|
} else {
|
||||||
config.max_position_embeddings
|
config.max_position_embeddings
|
||||||
};
|
};
|
||||||
|
Ok(max_position_embeddings)
|
||||||
Config {
|
|
||||||
max_position_embeddings,
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
|
||||||
|
|
||||||
let max_input_tokens = {
|
let max_input_tokens = {
|
||||||
match (args.max_input_tokens, args.max_input_length) {
|
match (args.max_input_tokens, args.max_input_length) {
|
||||||
@ -1315,7 +1312,7 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
}
|
}
|
||||||
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
|
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
|
||||||
(None, None) => {
|
(None, None) => {
|
||||||
let value = config.max_position_embeddings - 1;
|
let value = max_position_embeddings - 1;
|
||||||
tracing::info!("Default `max_input_tokens` to {value}");
|
tracing::info!("Default `max_input_tokens` to {value}");
|
||||||
value
|
value
|
||||||
}
|
}
|
||||||
@ -1325,7 +1322,7 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
match args.max_total_tokens {
|
match args.max_total_tokens {
|
||||||
Some(max_total_tokens) => max_total_tokens,
|
Some(max_total_tokens) => max_total_tokens,
|
||||||
None => {
|
None => {
|
||||||
let value = config.max_position_embeddings;
|
let value = max_position_embeddings;
|
||||||
tracing::info!("Default `max_total_tokens` to {value}");
|
tracing::info!("Default `max_total_tokens` to {value}");
|
||||||
value
|
value
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user