mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Fixing qwen2.

commit 9cec099aa4
parent ee47973a2f
@@ -4,7 +4,7 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers.models.qwen2 import Qwen2Tokenizer
+from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
 from text_generation_server.models.cache_manager import BLOCK_SIZE
@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
 )
-from transformers.models.qwen2 import Qwen2Config
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral):
         else:
             raise NotImplementedError("FlashQwen2 is only available on GPU")
 
-        tokenizer = Qwen2Tokenizer.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral):
             trust_remote_code=trust_remote_code,
         )
 
-        config = Qwen2Config.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize