diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index 33b053a6..181a93b1 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -3,7 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from typing import Optional
-from transformers.models.llama import LlamaTokenizerFast
+from transformers import AutoTokenizer
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
@@ -36,7 +36,7 @@ class FlashCohere(FlashCausalLM):
         else:
             raise NotImplementedError("FlashCohere is only available on GPU")
 
-        tokenizer = LlamaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
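
Note (illustrative, not part of the diff): AutoTokenizer dispatches on the tokenizer class recorded in the checkpoint's tokenizer_config.json, so a Cohere checkpoint loads its own fast tokenizer rather than being forced through LlamaTokenizerFast. A minimal sketch of that behavior; the model id below is purely illustrative:

    from transformers import AutoTokenizer

    # Illustrative model id; any checkpoint whose tokenizer_config.json
    # names a tokenizer class resolves the same way.
    tokenizer = AutoTokenizer.from_pretrained(
        "CohereForAI/c4ai-command-r-v01",
        padding_side="left",
    )

    # AutoTokenizer instantiates the concrete class declared by the
    # checkpoint, e.g. CohereTokenizerFast instead of LlamaTokenizerFast.
    print(type(tokenizer).__name__)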