diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 8edf0677..22cd0f57 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -181,7 +181,11 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type in {"gpt_bigcode", "gpt2"}: + if ( + model_type == "gpt_bigcode" + or model_type == "gpt2" + and model_id.startswith("bigcode/") + ): if FLASH_ATTENTION: return FlashSantacoderSharded( model_id,