fix: LlamaTokenizerFast to AutoTokenizer at flash_mistral.py

SeongBeomLEE 2024-03-11 13:27:09 +09:00
parent 7dbaf9e901
commit 2111ae1bd2


@@ -6,7 +6,7 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, AutoTokenizer
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type
@@ -317,13 +317,22 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
 
-        tokenizer = LlamaTokenizerFast.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
 
         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
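
For context, the fallback works because AutoTokenizer.from_pretrained dispatches to whatever tokenizer class the checkpoint declares in its own tokenizer_config.json, instead of assuming a Llama-compatible tokenizer. A minimal standalone sketch of the same pattern (the load_tokenizer helper and the model id in the usage comment are illustrative, not part of this commit):

    from transformers import AutoTokenizer
    from transformers.models.llama import LlamaTokenizerFast

    def load_tokenizer(model_id: str, revision=None, trust_remote_code=False):
        # Shared keyword arguments, mirroring the commit above.
        kwargs = dict(
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            # Fast path: most Mistral-style checkpoints ship
            # Llama-compatible tokenizer files.
            return LlamaTokenizerFast.from_pretrained(model_id, **kwargs)
        except Exception:
            # Fallback: let AutoTokenizer resolve the tokenizer class
            # from the checkpoint's own configuration.
            return AutoTokenizer.from_pretrained(model_id, **kwargs)

    # Hypothetical usage; any Hugging Face model id would do:
    # tokenizer = load_tokenizer("mistralai/Mistral-7B-v0.1")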