diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5cb66382..9613218d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -102,7 +102,7 @@ def get_model( config = AutoConfig.from_pretrained(model_id, revision=revision) model_type = config.model_type - if model_type == "gpt_bigcode": + if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"): if sharded: if not FLASH_ATTENTION: raise NotImplementedError(