mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
add peft param to launcher
This commit is contained in:
parent
8175a305aa
commit
6b5609b413
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,3 +2,4 @@
|
||||
target
|
||||
router/tokenizer.json
|
||||
*__pycache__*
|
||||
CUSTOM_MODELS/**
|
||||
|
@ -1,5 +1,12 @@
|
||||
FROM ghcr.io/ohmytofu-ai/tgi-angry:1.0.3-rc1
|
||||
FROM ghcr.io/huggingface/text-generation-inference:1.0.2
|
||||
COPY ./CUSTOM_MODELS/ /mnt/TOFU/HF_MODELS
|
||||
COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements.txt && \
|
||||
pip install ".[bnb, accelerate, quantize]" --no-cache-dir && \
|
||||
pip install -e . --force-reinstall --upgrade --no-deps
|
||||
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
|
2
Makefile
2
Makefile
@ -42,6 +42,8 @@ python-client-tests:
|
||||
|
||||
python-tests: python-server-tests python-client-tests
|
||||
|
||||
run-llama-peft:
|
||||
text-generation-launcher --model-id meta-llama/Llama-2-7b-hf --port 8080 --quantize bitsandbytes --peft-model-path /mnt/TOFU/text-generation-inference/CUSTOM_MODELS/asknature-llama2-70b
|
||||
run-falcon-7b-instruct:
|
||||
text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080
|
||||
|
||||
|
@ -100,6 +100,9 @@ struct Args {
|
||||
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
||||
model_id: String,
|
||||
|
||||
/// The Path for a local PEFT model adapter
|
||||
#[clap(long, env)]
|
||||
peft_model_path: Option<String>,
|
||||
/// The actual revision of the model if you're referring to a model
|
||||
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
|
||||
#[clap(long, env)]
|
||||
@ -339,6 +342,7 @@ enum ShardStatus {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn shard_manager(
|
||||
model_id: String,
|
||||
peft_model_path: Option<String>,
|
||||
revision: Option<String>,
|
||||
quantize: Option<Quantization>,
|
||||
dtype: Option<Dtype>,
|
||||
@ -393,6 +397,11 @@ fn shard_manager(
|
||||
shard_args.push("--sharded".to_string());
|
||||
}
|
||||
|
||||
if let Some(peft_model_path) = peft_model_path {
|
||||
shard_args.push("--peft-model-path".to_string());
|
||||
shard_args.push(peft_model_path.to_string())
|
||||
}
|
||||
|
||||
if let Some(quantize) = quantize {
|
||||
shard_args.push("--quantize".to_string());
|
||||
shard_args.push(quantize.to_string())
|
||||
@ -838,6 +847,7 @@ fn spawn_shards(
|
||||
// Start shard processes
|
||||
for rank in 0..num_shard {
|
||||
let model_id = args.model_id.clone();
|
||||
let peft_model_path = args.peft_model_path.clone();
|
||||
let revision = args.revision.clone();
|
||||
let uds_path = args.shard_uds_path.clone();
|
||||
let master_addr = args.master_addr.clone();
|
||||
@ -860,6 +870,7 @@ fn spawn_shards(
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_id,
|
||||
peft_model_path,
|
||||
revision,
|
||||
quantize,
|
||||
dtype,
|
||||
|
Loading…
Reference in New Issue
Block a user