mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
add peft param to launcher
This commit is contained in:
parent
8175a305aa
commit
6b5609b413
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,3 +2,4 @@
|
|||||||
target
|
target
|
||||||
router/tokenizer.json
|
router/tokenizer.json
|
||||||
*__pycache__*
|
*__pycache__*
|
||||||
|
CUSTOM_MODELS/**
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
FROM ghcr.io/ohmytofu-ai/tgi-angry:1.0.3-rc1
|
FROM ghcr.io/huggingface/text-generation-inference:1.0.2
|
||||||
COPY ./CUSTOM_MODELS/ /mnt/TOFU/HF_MODELS
|
COPY ./CUSTOM_MODELS/ /mnt/TOFU/HF_MODELS
|
||||||
|
COPY server server
|
||||||
|
COPY server/Makefile server/Makefile
|
||||||
|
RUN cd server && \
|
||||||
|
make gen-server && \
|
||||||
|
pip install -r requirements.txt && \
|
||||||
|
pip install ".[bnb, accelerate, quantize]" --no-cache-dir && \
|
||||||
|
pip install -e . --force-reinstall --upgrade --no-deps
|
||||||
|
|
||||||
ENTRYPOINT ["text-generation-launcher"]
|
ENTRYPOINT ["text-generation-launcher"]
|
||||||
CMD ["--json-output"]
|
CMD ["--json-output"]
|
||||||
|
2
Makefile
2
Makefile
@ -42,6 +42,8 @@ python-client-tests:
|
|||||||
|
|
||||||
python-tests: python-server-tests python-client-tests
|
python-tests: python-server-tests python-client-tests
|
||||||
|
|
||||||
|
run-llama-peft:
|
||||||
|
text-generation-launcher --model-id meta-llama/Llama-2-7b-hf --port 8080 --quantize bitsandbytes --peft-model-path /mnt/TOFU/text-generation-inference/CUSTOM_MODELS/asknature-llama2-70b
|
||||||
run-falcon-7b-instruct:
|
run-falcon-7b-instruct:
|
||||||
text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080
|
text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080
|
||||||
|
|
||||||
|
@ -100,6 +100,9 @@ struct Args {
|
|||||||
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
||||||
model_id: String,
|
model_id: String,
|
||||||
|
|
||||||
|
/// The path to a local PEFT model adapter
|
||||||
|
#[clap(long, env)]
|
||||||
|
peft_model_path: Option<String>,
|
||||||
/// The actual revision of the model if you're referring to a model
|
/// The actual revision of the model if you're referring to a model
|
||||||
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
|
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
@ -339,6 +342,7 @@ enum ShardStatus {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn shard_manager(
|
fn shard_manager(
|
||||||
model_id: String,
|
model_id: String,
|
||||||
|
peft_model_path: Option<String>,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
dtype: Option<Dtype>,
|
dtype: Option<Dtype>,
|
||||||
@ -393,6 +397,11 @@ fn shard_manager(
|
|||||||
shard_args.push("--sharded".to_string());
|
shard_args.push("--sharded".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(peft_model_path) = peft_model_path {
|
||||||
|
shard_args.push("--peft-model-path".to_string());
|
||||||
|
shard_args.push(peft_model_path.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(quantize) = quantize {
|
if let Some(quantize) = quantize {
|
||||||
shard_args.push("--quantize".to_string());
|
shard_args.push("--quantize".to_string());
|
||||||
shard_args.push(quantize.to_string())
|
shard_args.push(quantize.to_string())
|
||||||
@ -838,6 +847,7 @@ fn spawn_shards(
|
|||||||
// Start shard processes
|
// Start shard processes
|
||||||
for rank in 0..num_shard {
|
for rank in 0..num_shard {
|
||||||
let model_id = args.model_id.clone();
|
let model_id = args.model_id.clone();
|
||||||
|
let peft_model_path = args.peft_model_path.clone();
|
||||||
let revision = args.revision.clone();
|
let revision = args.revision.clone();
|
||||||
let uds_path = args.shard_uds_path.clone();
|
let uds_path = args.shard_uds_path.clone();
|
||||||
let master_addr = args.master_addr.clone();
|
let master_addr = args.master_addr.clone();
|
||||||
@ -860,6 +870,7 @@ fn spawn_shards(
|
|||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
shard_manager(
|
shard_manager(
|
||||||
model_id,
|
model_id,
|
||||||
|
peft_model_path,
|
||||||
revision,
|
revision,
|
||||||
quantize,
|
quantize,
|
||||||
dtype,
|
dtype,
|
||||||
|
Loading…
Reference in New Issue
Block a user