Fix GPT2 detection.

This commit is contained in:
Nicolas Patry 2024-02-26 11:20:39 +00:00
parent ed95f1982d
commit 680a52f2f2

View File

@ -181,7 +181,11 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type in {"gpt_bigcode", "gpt2"}: if (
model_type == "gpt_bigcode"
or model_type == "gpt2"
and model_id.startswith("bigcode/")
):
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashSantacoderSharded( return FlashSantacoderSharded(
model_id, model_id,