From 98dcc9beac3a100dda044a38499e7781db4438d9 Mon Sep 17 00:00:00 2001 From: Sida Date: Mon, 30 Sep 2024 21:57:07 -0400 Subject: [PATCH] allow revision for lora adapters from launcher --- Cargo.lock | 23 +++++++++---------- docs/source/conceptual/lora.md | 6 +++++ launcher/Cargo.toml | 1 + launcher/src/main.rs | 40 +++++++++++++++++++++++++++------- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6796212f..bf0fa17d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3422,14 +3422,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.8", + "regex-syntax 0.8.5", ] [[package]] @@ -3443,13 +3443,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -3460,9 +3460,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" @@ -4244,6 +4244,7 @@ dependencies = [ "nix 0.28.0", "once_cell", "pyo3", + "regex", "reqwest", "serde", "serde_json", @@ -4533,7 +4534,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", "serde", "serde_json", "spm_precompiled", @@ -4566,7 +4567,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", "serde", "serde_json", "spm_precompiled", diff --git a/docs/source/conceptual/lora.md b/docs/source/conceptual/lora.md index 0b7e3616..d1f4ce78 100644 --- a/docs/source/conceptual/lora.md +++ b/docs/source/conceptual/lora.md @@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia ``` +To specify model revision, use `adapter_id@revision`, as follows: + +```bash +LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2 +``` + To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"` ```bash diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 033a9a04..fdc3c02c 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -18,6 +18,7 @@ serde_json = "1.0.107" thiserror = "1.0.59" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +regex = "1.11.0" [dev-dependencies] float_eq = "1.0.1" diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 474a72d3..aba497d6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -5,6 +5,7 @@ use hf_hub::{ }; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; +use regex::Regex; use serde::Deserialize; use std::env; use std::ffi::OsString; @@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> { if adapter.contains('=') { continue; } - download_convert_model( - adapter, - None, - args.trust_remote_code, - args.huggingface_hub_cache.as_deref(), - args.weights_cache_override.as_deref(), - running.clone(), - )?; + + let adapter = adapter.trim(); + + // check if adapter has more than 1 '@' + if adapter.matches('@').count() > 1 { + return Err(LauncherError::ArgumentValidation(format!( + "Invalid LoRA adapter format: {}", + adapter + ))); + } + + // capture adapter_id, path, revision in format of adapter_id=path@revision + let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap(); + if let Some(caps) = re.captures(adapter) { + let adapter_id = caps.get(1).map_or("", |m| m.as_str()); + let revision = caps.get(3).map(|m| m.as_str()); + + download_convert_model( + adapter_id, + revision, + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + } else { + return Err(LauncherError::ArgumentValidation(format!( + "Invalid LoRA adapter format: {}", + adapter + ))); + } } }