From a072660bf51022f9e1d59b64efc5954a0e1eee45 Mon Sep 17 00:00:00 2001
From: Dong Shin
Date: Mon, 14 Aug 2023 21:20:18 +0900
Subject: [PATCH] fix: LlamaTokenizerFast to AutoTokenizer at flash_llama.py
 (#619)

# What does this PR do?

A few tokenizer_config files on the Hugging Face Hub still use `LlamaTokenizer`, which is why `LlamaTokenizer` is tried first. For the cases where a model uses the Llama architecture but not the Llama tokenizer, this PR makes the exception handler fall back to `AutoTokenizer`.

In the case of `decapoda-research/llama-7b-hf`, `LlamaTokenizer` is still specified in config.json, so it is loaded through `LlamaTokenizer`. If `LlamaTokenizer` raises an exception, `AutoTokenizer` resolves to `LlamaTokenizerFast` instead.

Fixes #560

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@Narsil
---
 server/text_generation_server/models/flash_llama.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 96fb0c26..063aa01e 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -2,7 +2,8 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
+from transformers import AutoConfig, AutoTokenizer
+from transformers.models.llama import LlamaTokenizer
 from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
@@ -44,7 +45,7 @@ class FlashLlama(FlashCausalLM):
                 trust_remote_code=trust_remote_code,
             )
         except Exception:
-            tokenizer = LlamaTokenizerFast.from_pretrained(
+            tokenizer = AutoTokenizer.from_pretrained(
                 model_id,
                 revision=revision,
                 padding_side="left",
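
For reviewers, a minimal sketch of how the tokenizer fallback reads once this patch is applied. The standalone `load_llama_tokenizer` helper and the `truncation_side` keyword are illustrative assumptions for this sketch; the actual change lives inline in `FlashLlama.__init__` as shown in the diff above.

```python
from transformers import AutoTokenizer
from transformers.models.llama import LlamaTokenizer


def load_llama_tokenizer(model_id, revision=None, trust_remote_code=False):
    # Repos such as decapoda-research/llama-7b-hf still reference
    # LlamaTokenizer in their config, so try the slow tokenizer first.
    try:
        return LlamaTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        # For Llama-architecture models that ship a different tokenizer,
        # fall back to AutoTokenizer, which typically resolves to
        # LlamaTokenizerFast (or whatever class the repo declares).
        return AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
```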