From ddc35e4eb88736378a328158d058a65c550e2bb2 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 29 Nov 2024 11:00:13 -0500 Subject: [PATCH] fix: add merge-lora arg for model id --- launcher/src/main.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fc40bdb1..72bd0ebd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1193,6 +1193,7 @@ fn download_convert_model( huggingface_hub_cache: Option<&str>, weights_cache_override: Option<&str>, running: Arc, + merge_lora: bool, ) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); @@ -1207,6 +1208,10 @@ fn download_convert_model( "--json-output".to_string(), ]; + if merge_lora { + download_args.push("--merge-lora".to_string()); + } + // Model optional revision if let Some(revision) = &revision { download_args.push("--revision".to_string()); @@ -1842,6 +1847,7 @@ fn main() -> Result<(), LauncherError> { args.huggingface_hub_cache.as_deref(), args.weights_cache_override.as_deref(), running.clone(), + true, // if its only a lora model - we should merge the lora adapters )?; // Download and convert lora adapters if any @@ -1875,6 +1881,7 @@ fn main() -> Result<(), LauncherError> { args.huggingface_hub_cache.as_deref(), args.weights_cache_override.as_deref(), running.clone(), + false, // avoid merging lora adapters if using multi-lora )?; } else { return Err(LauncherError::ArgumentValidation(format!(