mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
allow revision for lora adapters from launcher
This commit is contained in:
parent
d1f257ac56
commit
98dcc9beac
23
Cargo.lock
generated
23
Cargo.lock
generated
@ -3422,14 +3422,14 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex"
|
name = "regex"
|
||||||
version = "1.10.6"
|
version = "1.11.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
|
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aho-corasick",
|
"aho-corasick",
|
||||||
"memchr",
|
"memchr",
|
||||||
"regex-automata 0.4.7",
|
"regex-automata 0.4.8",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -3443,13 +3443,13 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex-automata"
|
name = "regex-automata"
|
||||||
version = "0.4.7"
|
version = "0.4.8"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aho-corasick",
|
"aho-corasick",
|
||||||
"memchr",
|
"memchr",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -3460,9 +3460,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex-syntax"
|
name = "regex-syntax"
|
||||||
version = "0.8.4"
|
version = "0.8.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "reqwest"
|
name = "reqwest"
|
||||||
@ -4244,6 +4244,7 @@ dependencies = [
|
|||||||
"nix 0.28.0",
|
"nix 0.28.0",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"pyo3",
|
"pyo3",
|
||||||
|
"regex",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
@ -4533,7 +4534,7 @@ dependencies = [
|
|||||||
"rayon",
|
"rayon",
|
||||||
"rayon-cond",
|
"rayon-cond",
|
||||||
"regex",
|
"regex",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"spm_precompiled",
|
"spm_precompiled",
|
||||||
@ -4566,7 +4567,7 @@ dependencies = [
|
|||||||
"rayon",
|
"rayon",
|
||||||
"rayon-cond",
|
"rayon-cond",
|
||||||
"regex",
|
"regex",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"spm_precompiled",
|
"spm_precompiled",
|
||||||
|
@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
|||||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To specify model revision, use `adapter_id@revision`, as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
|
||||||
|
```
|
||||||
|
|
||||||
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
|
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -18,6 +18,7 @@ serde_json = "1.0.107"
|
|||||||
thiserror = "1.0.59"
|
thiserror = "1.0.59"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
|
regex = "1.11.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
float_eq = "1.0.1"
|
float_eq = "1.0.1"
|
||||||
|
@ -5,6 +5,7 @@ use hf_hub::{
|
|||||||
};
|
};
|
||||||
use nix::sys::signal::{self, Signal};
|
use nix::sys::signal::{self, Signal};
|
||||||
use nix::unistd::Pid;
|
use nix::unistd::Pid;
|
||||||
|
use regex::Regex;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::ffi::OsString;
|
use std::ffi::OsString;
|
||||||
@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
if adapter.contains('=') {
|
if adapter.contains('=') {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
download_convert_model(
|
|
||||||
adapter,
|
let adapter = adapter.trim();
|
||||||
None,
|
|
||||||
args.trust_remote_code,
|
// check if adapter has more than 1 '@'
|
||||||
args.huggingface_hub_cache.as_deref(),
|
if adapter.matches('@').count() > 1 {
|
||||||
args.weights_cache_override.as_deref(),
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
running.clone(),
|
"Invalid LoRA adapter format: {}",
|
||||||
)?;
|
adapter
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// capture adapter_id, path, revision in format of adapter_id=path@revision
|
||||||
|
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
|
||||||
|
if let Some(caps) = re.captures(adapter) {
|
||||||
|
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
|
||||||
|
let revision = caps.get(3).map(|m| m.as_str());
|
||||||
|
|
||||||
|
download_convert_model(
|
||||||
|
adapter_id,
|
||||||
|
revision,
|
||||||
|
args.trust_remote_code,
|
||||||
|
args.huggingface_hub_cache.as_deref(),
|
||||||
|
args.weights_cache_override.as_deref(),
|
||||||
|
running.clone(),
|
||||||
|
)?;
|
||||||
|
} else {
|
||||||
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
|
"Invalid LoRA adapter format: {}",
|
||||||
|
adapter
|
||||||
|
)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user