Fixing qwen2.

This commit is contained in:
Nicolas Patry 2024-04-26 16:41:07 +02:00
parent ee47973a2f
commit 9cec099aa4

View File

@ -4,7 +4,7 @@ import torch
import torch.distributed
from opentelemetry import trace
from transformers.models.qwen2 import Qwen2Tokenizer
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models.cache_manager import BLOCK_SIZE
@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from transformers.models.qwen2 import Qwen2Config
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral):
else:
raise NotImplementedError("FlashQwen2 is only available on GPU")
tokenizer = Qwen2Tokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral):
trust_remote_code=trust_remote_code,
)
config = Qwen2Config.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize