add peft param to launcher

Chris 2023-08-28 00:05:45 +02:00
parent 8175a305aa
commit 6b5609b413
4 changed files with 22 additions and 1 deletion

.gitignore

@@ -2,3 +2,4 @@
 target
 router/tokenizer.json
 *__pycache__*
+CUSTOM_MODELS/**

Dockerfile

@@ -1,5 +1,12 @@
-FROM ghcr.io/ohmytofu-ai/tgi-angry:1.0.3-rc1
+FROM ghcr.io/huggingface/text-generation-inference:1.0.2
 COPY ./CUSTOM_MODELS/ /mnt/TOFU/HF_MODELS
+COPY server server
+COPY server/Makefile server/Makefile
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements.txt && \
+    pip install ".[bnb, accelerate, quantize]" --no-cache-dir && \
+    pip install -e . --force-reinstall --upgrade --no-deps

 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]

Makefile

@@ -42,6 +42,8 @@ python-client-tests:
 python-tests: python-server-tests python-client-tests

+run-llama-peft:
+	text-generation-launcher --model-id meta-llama/Llama-2-7b-hf --port 8080 --quantize bitsandbytes --peft-model-path /mnt/TOFU/text-generation-inference/CUSTOM_MODELS/asknature-llama2-70b
 run-falcon-7b-instruct:
 	text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080

launcher/src/main.rs

@@ -100,6 +100,9 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,
+    /// The Path for a local PEFT model adapter
+    #[clap(long, env)]
+    peft_model_path: Option<String>,
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
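The new field is wired in through clap's derive macro with long, env, so the adapter location can be supplied either as --peft-model-path on the command line or, assuming clap's env feature is enabled as for the launcher's other flags, through a PEFT_MODEL_PATH environment variable; when neither is given the field stays None. A minimal standalone sketch of that behaviour follows; the demo main and its printouts are illustrative, not part of the launcher:

// Minimal sketch of how the new flag parses; not the launcher's full Args struct.
// Assumes clap 4 with the "derive" and "env" features enabled.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Mirrors the launcher's existing --model-id flag
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// Optional path to a local PEFT adapter; stays None when the flag is omitted
    #[clap(long, env)]
    peft_model_path: Option<String>,
}

fn main() {
    // --peft-model-path /some/path on the CLI or the PEFT_MODEL_PATH env var
    // both populate the field.
    let args = Args::parse();
    match &args.peft_model_path {
        Some(path) => println!("loading PEFT adapter from {path}"),
        None => println!("no PEFT adapter configured"),
    }
}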
@@ -339,6 +342,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_id: String,
+    peft_model_path: Option<String>,
     revision: Option<String>,
     quantize: Option<Quantization>,
     dtype: Option<Dtype>,
@@ -393,6 +397,11 @@ fn shard_manager(
         shard_args.push("--sharded".to_string());
     }

+    if let Some(peft_model_path) = peft_model_path {
+        shard_args.push("--peft-model-path".to_string());
+        shard_args.push(peft_model_path.to_string())
+    }
+
     if let Some(quantize) = quantize {
         shard_args.push("--quantize".to_string());
         shard_args.push(quantize.to_string())
@@ -838,6 +847,7 @@ fn spawn_shards(
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
+        let peft_model_path = args.peft_model_path.clone();
         let revision = args.revision.clone();
         let uds_path = args.shard_uds_path.clone();
         let master_addr = args.master_addr.clone();
@@ -860,6 +870,7 @@ fn spawn_shards(
         thread::spawn(move || {
             shard_manager(
                 model_id,
+                peft_model_path,
                 revision,
                 quantize,
                 dtype,
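Taken together, the launcher change follows the same pattern already used for quantize: spawn_shards clones the optional path once per rank, shard_manager appends --peft-model-path plus its value to the shard's argument vector only when the option is set, and that vector is handed to the spawned shard process. Below is a self-contained sketch of that conditional forwarding; the build_shard_args helper, the adapter path, and the echo stand-in for the shard binary are illustrative assumptions, not the launcher's real code:

// Sketch of conditionally forwarding an optional flag to a child process.
use std::process::Command;

// Hypothetical helper mirroring the if-let pattern in shard_manager above.
fn build_shard_args(model_id: &str, peft_model_path: Option<&str>) -> Vec<String> {
    let mut shard_args = vec![model_id.to_string()];

    // Only emit the flag when a PEFT adapter path was actually provided.
    if let Some(path) = peft_model_path {
        shard_args.push("--peft-model-path".to_string());
        shard_args.push(path.to_string());
    }
    shard_args
}

fn main() {
    let args = build_shard_args(
        "meta-llama/Llama-2-7b-hf",
        Some("/mnt/TOFU/HF_MODELS/asknature-llama2-70b"),
    );

    // echo stands in for the real shard binary and just prints the forwarded args.
    let status = Command::new("echo")
        .args(&args)
        .status()
        .expect("failed to spawn child process");
    println!("child exited with {status}");
}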