From 6b5609b41350569abb5944e340c52312d6007bf5 Mon Sep 17 00:00:00 2001 From: Chris Date: Mon, 28 Aug 2023 00:05:45 +0200 Subject: [PATCH] add peft param to launcher --- .gitignore | 1 + Dockerfile.bake-peft-into-container | 9 ++++++++- Makefile | 2 ++ launcher/src/main.rs | 11 +++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 20c9baee..42edff5c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target router/tokenizer.json *__pycache__* +CUSTOM_MODELS/** diff --git a/Dockerfile.bake-peft-into-container b/Dockerfile.bake-peft-into-container index 820ddea6..d9f7b6ac 100644 --- a/Dockerfile.bake-peft-into-container +++ b/Dockerfile.bake-peft-into-container @@ -1,5 +1,12 @@ -FROM ghcr.io/ohmytofu-ai/tgi-angry:1.0.3-rc1 +FROM ghcr.io/huggingface/text-generation-inference:1.0.2 COPY ./CUSTOM_MODELS/ /mnt/TOFU/HF_MODELS +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements.txt && \ + pip install ".[bnb, accelerate, quantize]" --no-cache-dir && \ + pip install -e . 
--force-reinstall --upgrade --no-deps ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/Makefile b/Makefile index 7f534c7c..ec59a2cc 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,8 @@ python-client-tests: python-tests: python-server-tests python-client-tests +run-llama-peft: + text-generation-launcher --model-id meta-llama/Llama-2-7b-hf --port 8080 --quantize bitsandbytes --peft-model-path /mnt/TOFU/text-generation-inference/CUSTOM_MODELS/asknature-llama2-70b run-falcon-7b-instruct: text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080 diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 75762712..8ccba3ed 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -100,6 +100,9 @@ struct Args { #[clap(default_value = "bigscience/bloom-560m", long, env)] model_id: String, + /// The path to a local PEFT model adapter + #[clap(long, env)] + peft_model_path: Option<String>, /// The actual revision of the model if you're referring to a model /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`. 
#[clap(long, env)] @@ -339,6 +342,7 @@ enum ShardStatus { #[allow(clippy::too_many_arguments)] fn shard_manager( model_id: String, + peft_model_path: Option<String>, revision: Option<String>, quantize: Option<Quantization>, dtype: Option<Dtype>, @@ -393,6 +397,11 @@ fn shard_manager( shard_args.push("--sharded".to_string()); } + if let Some(peft_model_path) = peft_model_path { + shard_args.push("--peft-model-path".to_string()); + shard_args.push(peft_model_path.to_string()) + } + if let Some(quantize) = quantize { shard_args.push("--quantize".to_string()); shard_args.push(quantize.to_string()) @@ -838,6 +847,7 @@ fn spawn_shards( // Start shard processes for rank in 0..num_shard { let model_id = args.model_id.clone(); + let peft_model_path = args.peft_model_path.clone(); let revision = args.revision.clone(); let uds_path = args.shard_uds_path.clone(); let master_addr = args.master_addr.clone(); @@ -860,6 +870,7 @@ fn spawn_shards( thread::spawn(move || { shard_manager( model_id, + peft_model_path, revision, quantize, dtype,