mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
allow revision for lora adapters from launcher
This commit is contained in:
parent
d1f257ac56
commit
98dcc9beac
23
Cargo.lock
generated
23
Cargo.lock
generated
@ -3422,14 +3422,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.6"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
|
||||
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata 0.4.7",
|
||||
"regex-syntax 0.8.4",
|
||||
"regex-automata 0.4.8",
|
||||
"regex-syntax 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3443,13 +3443,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.7"
|
||||
version = "0.4.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
||||
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax 0.8.4",
|
||||
"regex-syntax 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3460,9 +3460,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.4"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
||||
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
@ -4244,6 +4244,7 @@ dependencies = [
|
||||
"nix 0.28.0",
|
||||
"once_cell",
|
||||
"pyo3",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@ -4533,7 +4534,7 @@ dependencies = [
|
||||
"rayon",
|
||||
"rayon-cond",
|
||||
"regex",
|
||||
"regex-syntax 0.8.4",
|
||||
"regex-syntax 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spm_precompiled",
|
||||
@ -4566,7 +4567,7 @@ dependencies = [
|
||||
"rayon",
|
||||
"rayon-cond",
|
||||
"regex",
|
||||
"regex-syntax 0.8.4",
|
||||
"regex-syntax 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spm_precompiled",
|
||||
|
@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||
```
|
||||
|
||||
To specify a model revision, use `adapter_id@revision`, as follows:
|
||||
|
||||
```bash
|
||||
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
|
||||
```
|
||||
|
||||
To use a locally stored LoRA adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}`.
|
||||
|
||||
```bash
|
||||
|
@ -18,6 +18,7 @@ serde_json = "1.0.107"
|
||||
thiserror = "1.0.59"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
regex = "1.11.0"
|
||||
|
||||
[dev-dependencies]
|
||||
float_eq = "1.0.1"
|
||||
|
@ -5,6 +5,7 @@ use hf_hub::{
|
||||
};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use regex::Regex;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> {
|
||||
if adapter.contains('=') {
|
||||
continue;
|
||||
}
|
||||
|
||||
let adapter = adapter.trim();
|
||||
|
||||
// check if adapter has more than 1 '@'
|
||||
if adapter.matches('@').count() > 1 {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"Invalid LoRA adapter format: {}",
|
||||
adapter
|
||||
)));
|
||||
}
|
||||
|
||||
// capture adapter_id, path, revision in format of adapter_id=path@revision
|
||||
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
|
||||
if let Some(caps) = re.captures(adapter) {
|
||||
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
|
||||
let revision = caps.get(3).map(|m| m.as_str());
|
||||
|
||||
download_convert_model(
|
||||
adapter,
|
||||
None,
|
||||
adapter_id,
|
||||
revision,
|
||||
args.trust_remote_code,
|
||||
args.huggingface_hub_cache.as_deref(),
|
||||
args.weights_cache_override.as_deref(),
|
||||
running.clone(),
|
||||
)?;
|
||||
} else {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"Invalid LoRA adapter format: {}",
|
||||
adapter
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user