Loading config *after* checking for model name.

This commit is contained in:
Nicolas Patry 2023-05-15 10:01:59 +02:00
parent bef3458ee8
commit cd8477bcf8

View File

@@ -99,10 +99,21 @@ def get_model(
else:
return Galactica(model_id, revision, quantize=quantize)
if model_id.startswith("bigcode/"):
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision)
model_type = config.model_type
if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"):
if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"):
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(