fix: add merge-lora arg for model id (#2788)

This commit is contained in:
drbh 2024-12-01 23:52:02 -05:00 committed by GitHub
parent a35d1e6fe5
commit 2c74c55637
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1193,6 +1193,7 @@ fn download_convert_model(
huggingface_hub_cache: Option<&str>, huggingface_hub_cache: Option<&str>,
weights_cache_override: Option<&str>, weights_cache_override: Option<&str>,
running: Arc<AtomicBool>, running: Arc<AtomicBool>,
merge_lora: bool,
) -> Result<(), LauncherError> { ) -> Result<(), LauncherError> {
// Enter download tracing span // Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered(); let _span = tracing::span!(tracing::Level::INFO, "download").entered();
@ -1207,6 +1208,10 @@ fn download_convert_model(
"--json-output".to_string(), "--json-output".to_string(),
]; ];
if merge_lora {
download_args.push("--merge-lora".to_string());
}
// Model optional revision // Model optional revision
if let Some(revision) = &revision { if let Some(revision) = &revision {
download_args.push("--revision".to_string()); download_args.push("--revision".to_string());
@ -1842,6 +1847,7 @@ fn main() -> Result<(), LauncherError> {
args.huggingface_hub_cache.as_deref(), args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(), args.weights_cache_override.as_deref(),
running.clone(), running.clone(),
true, // if its only a lora model - we should merge the lora adapters
)?; )?;
// Download and convert lora adapters if any // Download and convert lora adapters if any
@ -1875,6 +1881,7 @@ fn main() -> Result<(), LauncherError> {
args.huggingface_hub_cache.as_deref(), args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(), args.weights_cache_override.as_deref(),
running.clone(), running.clone(),
false, // avoid merging lora adapters if using multi-lora
)?; )?;
} else { } else {
return Err(LauncherError::ArgumentValidation(format!( return Err(LauncherError::ArgumentValidation(format!(