text-generation-inference/server/text_generation_server/utils/peft.py
Diwank Singh Tomer 91111a0dc2
Fix missing trust_remote_code flag for AutoTokenizer in utils.peft (#1270)
The PEFT loading function did not pass
`trust_remote_code=trust_remote_code` to `AutoTokenizer.from_pretrained`,
so custom tokenizer code could not be found.
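
The effect of the missing argument can be illustrated with a small, self-contained sketch. Everything in it is hypothetical: `fake_from_pretrained` is a stand-in written for this example, not the `transformers` implementation, and the repository id is made up. It only models the idea that a loader refuses custom code unless the caller's opt-in flag actually reaches it.

```python
# Hypothetical stand-in for a loader such as AutoTokenizer.from_pretrained.
# It is NOT the transformers implementation; it only models the idea that
# repositories shipping custom code need an explicit opt-in.
REPOS_WITH_CUSTOM_CODE = {"org/custom-tokenizer"}  # made-up repository id


def fake_from_pretrained(model_id, trust_remote_code=False):
    """Refuse repositories with custom code unless the caller opted in."""
    needs_custom_code = model_id in REPOS_WITH_CUSTOM_CODE
    if needs_custom_code and not trust_remote_code:
        raise ValueError(f"{model_id} needs trust_remote_code=True")
    return {"model_id": model_id, "custom_code": needs_custom_code}


def load_tokenizer_before_fix(base_model_id, trust_remote_code):
    # The caller's flag is dropped, so the loader falls back to its default.
    return fake_from_pretrained(base_model_id)


def load_tokenizer_after_fix(base_model_id, trust_remote_code):
    # The caller's flag is forwarded unchanged.
    return fake_from_pretrained(base_model_id, trust_remote_code=trust_remote_code)
```

In this sketch, `load_tokenizer_before_fix("org/custom-tokenizer", True)` raises even though the caller opted in, while `load_tokenizer_after_fix` with the same arguments succeeds.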


@Narsil
2023-11-23 12:41:05 +01:00

45 lines
1.4 KiB
Python

import os
import json
from loguru import logger
import torch

from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM


def download_and_unload_peft(model_id, revision, trust_remote_code):
    torch_dtype = torch.float16

    logger.info("Peft model detected.")
    logger.info("Loading the model it might take a while without feedback")
    # Try to load the adapter as a causal LM first; on any failure,
    # fall back to loading it as a seq2seq model.
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    except Exception:
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    logger.info("Loaded.")
    logger.info("Merging the lora weights.")

    # Remember the base model id before merging; the tokenizer is loaded from it.
    base_model_id = model.peft_config["default"].base_model_name_or_path

    # Merge the adapter weights into the base model.
    model = model.merge_and_unload()

    # The merged model is saved to a local directory named after model_id.
    os.makedirs(model_id, exist_ok=True)
    cache_dir = model_id
    logger.info(f"Saving the newly created merged model to {cache_dir}")
    # Forward trust_remote_code so that custom tokenizer code can be loaded.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(cache_dir, safe_serialization=True)
    model.config.save_pretrained(cache_dir)
    tokenizer.save_pretrained(cache_dir)