Mirror of https://github.com/huggingface/text-generation-inference.git
* feat: first draft load multiple lora
* feat: load weights within layer and refactor lora pass
* fix: refactor and reduce lora math
* feat: baseline impl single request multi lora support
* feat: prefer lorax implementation and port loading logic
* fix: prefer adapter_data and refactors
* feat: perfer loraxs custom punica kernels and add mlp loras
* fix: adjust batch for bgmv
* fix: adjust adapter_segments logic when in batch
* fix: refactor and move changes to v3 proto
* fix: pass model_id for all flash causal lms
* fix: pass model_id for all causal and seq2seq lms
* fix: add model_id to model test
* feat: add lora support to mistral and refactors
* feat: prefer model id in request
* fix: include rust code for adapter id
* feat: bump launcher and add new lora docs
* feat: support base model generation and refactors
* fix: rename doc to retry ci build
* feat: support if vlm models
* fix: add adapter_data param and avoid missing layers
* fix: add adapter_data param to phi and neox
* fix: update all models forwards to include adapter_data
* fix: add model_id to IdeficsCausalLM
* Update lora.md (fixed a typo)
* Update lora.md (fixing spam image)
* fix: add lora kernel to dockerfile, support running without kernels and refactors
* fix: avoid dockerfile conflict
* fix: refactors and adjust flash llama lora logic
* fix: skip llama test due to CI issue (temp)
* fix: skip llama test CI (temp) 2
* fix: revert skips and prefer updated ci token for tests
* fix: refactors and helpful comments
* fix: add noop in TensorParallelAdapterRowLinear too
* fix: refactor and move shard_lora_weights logic
* fix: exit early if no adapter_data

---------

Co-authored-by: Derek <datavistics@gmail.com>
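The PR squashed above adds multi-LoRA serving: adapters are preloaded at launch and selected per request. As a client-side sketch only, assuming the request shape from the lora docs this PR adds (the server URL, adapter name, and the adapter_id generate parameter are taken from those docs, not from this file):

import requests

# Hypothetical local TGI endpoint; assumes the server was launched with the
# adapter preloaded, e.g. via LORA_ADAPTERS=predibase/customer_support.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is this support ticket about?",
        "parameters": {
            "max_new_tokens": 20,
            # Select one of the preloaded adapters; omit to hit the base model.
            "adapter_id": "predibase/customer_support",
        },
    },
)
print(response.json())

The file shown below is the server's existing PEFT handling, which eagerly merges a single adapter into its base model at download time.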
import os
from typing import Union

import torch
from loguru import logger
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
from transformers import AutoTokenizer

def download_and_unload_peft(model_id, revision, trust_remote_code):
    torch_dtype = torch.float16

    logger.info("Trying to load a Peft model. It might take a while without feedback")
    # Try the causal-LM head first; fall back to seq2seq if the base model
    # is an encoder-decoder architecture.
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    except Exception:
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    logger.info("Peft model detected.")
    logger.info("Merging the lora weights.")

    base_model_id = model.peft_config["default"].base_model_name_or_path

    # Fold the LoRA weights into the base model and drop the adapter wrappers.
    model = model.merge_and_unload()

    os.makedirs(model_id, exist_ok=True)
    cache_dir = model_id
    logger.info(f"Saving the newly created merged model to {cache_dir}")
    # The tokenizer comes from the base model, not the adapter repo.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(cache_dir, safe_serialization=True)
    model.config.save_pretrained(cache_dir)
    tokenizer.save_pretrained(cache_dir)

def download_peft(
    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
):
    torch_dtype = torch.float16
    # Loading the model is enough to pull the adapter and base weights into
    # the local cache; the instantiated model itself is discarded.
    try:
        _model = AutoPeftModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    except Exception:
        _model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    logger.info("Peft model downloaded.")
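As a usage sketch (not part of the original file): the two helpers above could be driven directly like this, with a placeholder adapter id; inside TGI they are normally invoked from the weight-download path rather than by user code.

if __name__ == "__main__":
    # Hypothetical adapter repo id; any repo containing an adapter_config.json
    # (i.e. a PEFT adapter) would work here.
    adapter_id = "some-user/some-lora-adapter"

    # Download the adapter plus its base weights into the local HF cache.
    download_peft(adapter_id, revision="main", trust_remote_code=False)

    # Merge the LoRA weights into the base model and write the merged
    # checkpoint (weights, config, tokenizer) to a local directory named
    # after the adapter id.
    download_and_unload_peft(adapter_id, revision="main", trust_remote_code=False)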