fix: LlamaTokenizerFast to AutoTokenizer at flash_mistral.py (#1637)

# What does this PR do?

There are a few cases where a model uses the Mistral or Mixtral architecture but does not ship a Llama tokenizer. For those cases, this PR wraps the `LlamaTokenizerFast` load in a try/except and falls back to `AutoTokenizer` when it fails.
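
For illustration, a minimal standalone sketch of that fallback pattern outside of `flash_mistral.py`; the `model_id` below is hypothetical and only stands in for a Mistral/Mixtral-architecture repository that does not use a Llama tokenizer:

```python
# Minimal sketch of the fallback behaviour (the model id is hypothetical).
from transformers import AutoTokenizer
from transformers.models.llama import LlamaTokenizerFast

model_id = "some-org/mistral-arch-custom-tokenizer"  # hypothetical repo

try:
    # Succeeds only if the repository actually ships a Llama-style tokenizer.
    tokenizer = LlamaTokenizerFast.from_pretrained(model_id)
except Exception:
    # Otherwise let AutoTokenizer resolve the correct tokenizer class
    # from the repository's tokenizer_config.json.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
```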

Similar PR #619

@Narsil
SeongBeomLEE 2024-03-23 01:13:13 +09:00 committed by Karol Damaszke
parent c4f92ec4b5
commit 097e72a672


@@ -6,7 +6,7 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, AutoTokenizer
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type
@@ -317,13 +317,22 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
-        tokenizer = LlamaTokenizerFast.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code