allow revision for lora adapters from launcher

This commit is contained in:
Sida 2024-09-30 21:57:07 -04:00
parent d1f257ac56
commit 98dcc9beac
4 changed files with 51 additions and 19 deletions

23
Cargo.lock generated
View File

@ -3422,14 +3422,14 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.6"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
"regex-automata 0.4.8",
"regex-syntax 0.8.5",
]
[[package]]
@ -3443,13 +3443,13 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.7"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
]
[[package]]
@ -3460,9 +3460,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
@ -4244,6 +4244,7 @@ dependencies = [
"nix 0.28.0",
"once_cell",
"pyo3",
"regex",
"reqwest",
"serde",
"serde_json",
@ -4533,7 +4534,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
@ -4566,7 +4567,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",

View File

@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
To pin an adapter to a specific revision, append `@revision` to its ID (`adapter_id@revision`), as follows:
```bash
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
```
To use a locally stored LoRA adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}`.
```bash

View File

@ -18,6 +18,7 @@ serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
regex = "1.11.0"
[dev-dependencies]
float_eq = "1.0.1"

View File

@ -5,6 +5,7 @@ use hf_hub::{
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> {
if adapter.contains('=') {
continue;
}
download_convert_model(
adapter,
None,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
let adapter = adapter.trim();
// check if adapter has more than 1 '@'
if adapter.matches('@').count() > 1 {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
// capture adapter_id, path, revision in format of adapter_id=path@revision
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
if let Some(caps) = re.captures(adapter) {
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
let revision = caps.get(3).map(|m| m.as_str());
download_convert_model(
adapter_id,
revision,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
} else {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
}
}