diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py
index b5411e22..367d3db0 100644
--- a/server/text_generation_server/models/flash_dbrx.py
+++ b/server/text_generation_server/models/flash_dbrx.py
@@ -38,7 +38,7 @@ class FlashDbrx(FlashCausalLM):
             raise NotImplementedError("FlashDBRX is only available on GPU")
 
         try:
-            tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer = GPT2TokenizerFast.from_pretrained(
                 model_id,
                 revision=revision,
                 padding_side="left",
@@ -48,16 +48,27 @@ class FlashDbrx(FlashCausalLM):
                 from_slow=False,
             )
         except:
-            # FIXME: change back to model id once the tokenizer.json is merged
-            tokenizer = GPT2TokenizerFast.from_pretrained(
-                "Xenova/dbrx-instruct-tokenizer",
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-                use_fast=True,
-                from_slow=False,
-            )
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    model_id,
+                    revision=revision,
+                    padding_side="left",
+                    truncation_side="left",
+                    trust_remote_code=trust_remote_code,
+                    use_fast=True,
+                    from_slow=False,
+                )
+            except:
+                # FIXME: change back to model id once the tokenizer.json is merged
+                tokenizer = GPT2TokenizerFast.from_pretrained(
+                    "Xenova/dbrx-instruct-tokenizer",
+                    revision=revision,
+                    padding_side="left",
+                    truncation_side="left",
+                    trust_remote_code=trust_remote_code,
+                    use_fast=True,
+                    from_slow=False,
+                )
 
         config = DbrxConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
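
After the patch, tokenizer loading becomes a three-stage fallback: first the fast GPT-2 tokenizer from the model repo itself, then a generic AutoTokenizer lookup, and only then the Xenova/dbrx-instruct-tokenizer mirror flagged by the FIXME. A minimal standalone sketch of that chain (the load_dbrx_tokenizer helper is hypothetical, not part of this patch):

from transformers import AutoTokenizer, GPT2TokenizerFast

def load_dbrx_tokenizer(model_id, revision=None, trust_remote_code=False):
    # Hypothetical helper mirroring the patched fallback order; not in the PR itself.
    kwargs = dict(
        revision=revision,
        padding_side="left",
        truncation_side="left",
        trust_remote_code=trust_remote_code,
        use_fast=True,
        from_slow=False,
    )
    try:
        # 1. Fast GPT-2 tokenizer straight from the model repo.
        return GPT2TokenizerFast.from_pretrained(model_id, **kwargs)
    except Exception:
        pass
    try:
        # 2. Fall back to whatever tokenizer class the repo's config declares.
        return AutoTokenizer.from_pretrained(model_id, **kwargs)
    except Exception:
        # 3. Last resort: the community mirror that ships a working tokenizer.json.
        return GPT2TokenizerFast.from_pretrained(
            "Xenova/dbrx-instruct-tokenizer", **kwargs
        )

Note that the patch keeps the bare except: clauses from the original code, so any failure (network, auth, missing files) silently advances to the next stage.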