From a788888619bd5e7e3c726cd77e845802e6e23c4e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 26 Apr 2024 19:19:08 +0200 Subject: [PATCH] Fixing qwen2. (#1818) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/text_generation_server/models/flash_qwen2.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index c3c63516..cb3cf6b0 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -4,7 +4,7 @@ import torch import torch.distributed from opentelemetry import trace -from transformers.models.qwen2 import Qwen2Tokenizer +from transformers import AutoTokenizer, AutoConfig from typing import Optional from text_generation_server.models.cache_manager import BLOCK_SIZE @@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import ( from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2ForCausalLM, ) -from transformers.models.qwen2 import Qwen2Config from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral): else: raise NotImplementedError("FlashQwen2 is only available on GPU") - tokenizer = Qwen2Tokenizer.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", @@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral): trust_remote_code=trust_remote_code, ) - config = Qwen2Config.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize