diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fc40bdb1..72bd0ebd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1193,6 +1193,7 @@ fn download_convert_model( huggingface_hub_cache: Option<&str>, weights_cache_override: Option<&str>, running: Arc, + merge_lora: bool, ) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); @@ -1207,6 +1208,10 @@ fn download_convert_model( "--json-output".to_string(), ]; + if merge_lora { + download_args.push("--merge-lora".to_string()); + } + // Model optional revision if let Some(revision) = &revision { download_args.push("--revision".to_string()); @@ -1842,6 +1847,7 @@ fn main() -> Result<(), LauncherError> { args.huggingface_hub_cache.as_deref(), args.weights_cache_override.as_deref(), running.clone(), + true, // if its only a lora model - we should merge the lora adapters )?; // Download and convert lora adapters if any @@ -1875,6 +1881,7 @@ fn main() -> Result<(), LauncherError> { args.huggingface_hub_cache.as_deref(), args.weights_cache_override.as_deref(), running.clone(), + false, // avoid merging lora adapters if using multi-lora )?; } else { return Err(LauncherError::ArgumentValidation(format!(