mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00
Remove lambda for cleaner function.
This commit is contained in:
parent
32f6416358
commit
c53968dc45
@ -24,6 +24,43 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
|||||||
|
|
||||||
mod env_runtime;
|
mod env_runtime;
|
||||||
|
|
||||||
|
fn get_config(
|
||||||
|
model_id: &str,
|
||||||
|
revision: &Option<String>,
|
||||||
|
) -> Result<Config, Box<dyn std::error::Error>> {
|
||||||
|
let mut path = std::path::Path::new(model_id).to_path_buf();
|
||||||
|
let model_id = model_id.to_string();
|
||||||
|
let filename = if !path.exists() {
|
||||||
|
// Assume it's a hub id
|
||||||
|
|
||||||
|
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||||
|
// env variable has precedence over on file token.
|
||||||
|
ApiBuilder::new().with_token(Some(token)).build()?
|
||||||
|
} else {
|
||||||
|
Api::new()?
|
||||||
|
};
|
||||||
|
let repo = if let Some(ref revision) = revision {
|
||||||
|
api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
revision.to_string(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
api.model(model_id)
|
||||||
|
};
|
||||||
|
repo.get("config.json")?
|
||||||
|
} else {
|
||||||
|
path.push("config.json");
|
||||||
|
path
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = std::fs::read_to_string(filename)?;
|
||||||
|
let config: RawConfig = serde_json::from_str(&content)?;
|
||||||
|
|
||||||
|
let config: Config = config.into();
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||||
@ -1504,40 +1541,7 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
|
|
||||||
tracing::info!("{:#?}", args);
|
tracing::info!("{:#?}", args);
|
||||||
|
|
||||||
let get_config = || -> Result<Config, Box<dyn std::error::Error>> {
|
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
|
||||||
let model_id = args.model_id.clone();
|
|
||||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
|
||||||
let filename = if !path.exists() {
|
|
||||||
// Assume it's a hub id
|
|
||||||
|
|
||||||
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
|
||||||
// env variable has precedence over on file token.
|
|
||||||
ApiBuilder::new().with_token(Some(token)).build()?
|
|
||||||
} else {
|
|
||||||
Api::new()?
|
|
||||||
};
|
|
||||||
let repo = if let Some(ref revision) = args.revision {
|
|
||||||
api.repo(Repo::with_revision(
|
|
||||||
model_id,
|
|
||||||
RepoType::Model,
|
|
||||||
revision.to_string(),
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
api.model(model_id)
|
|
||||||
};
|
|
||||||
repo.get("config.json")?
|
|
||||||
} else {
|
|
||||||
path.push("config.json");
|
|
||||||
path
|
|
||||||
};
|
|
||||||
|
|
||||||
let content = std::fs::read_to_string(filename)?;
|
|
||||||
let config: RawConfig = serde_json::from_str(&content)?;
|
|
||||||
|
|
||||||
let config: Config = config.into();
|
|
||||||
Ok(config)
|
|
||||||
};
|
|
||||||
let config: Option<Config> = get_config().ok();
|
|
||||||
let quantize = config.as_ref().and_then(|c| c.quantize);
|
let quantize = config.as_ref().and_then(|c| c.quantize);
|
||||||
// Quantization usually means you're even more RAM constrained.
|
// Quantization usually means you're even more RAM constrained.
|
||||||
let max_default = 4096;
|
let max_default = 4096;
|
||||||
|
Loading…
Reference in New Issue
Block a user