Fixing qwen2.

Nicolas Patry 2024-04-26 16:41:07 +02:00
parent ee47973a2f
commit 9cec099aa4


@@ -4,7 +4,7 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers.models.qwen2 import Qwen2Tokenizer
+from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
 from text_generation_server.models.cache_manager import BLOCK_SIZE
@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
 )
-from transformers.models.qwen2 import Qwen2Config
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral):
else: else:
raise NotImplementedError("FlashQwen2 is only available on GPU") raise NotImplementedError("FlashQwen2 is only available on GPU")
tokenizer = Qwen2Tokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
padding_side="left", padding_side="left",
@@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral):
             trust_remote_code=trust_remote_code,
         )
-        config = Qwen2Config.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize