mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Loading config *after* checking for model name.
This commit is contained in:
parent
bef3458ee8
commit
cd8477bcf8
@ -99,10 +99,21 @@ def get_model(
|
||||
else:
|
||||
return Galactica(model_id, revision, quantize=quantize)
|
||||
|
||||
if model_id.startswith("bigcode/"):
|
||||
if sharded:
|
||||
if not FLASH_ATTENTION:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
||||
)
|
||||
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
|
||||
else:
|
||||
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
||||
return santacoder_cls(model_id, revision, quantize=quantize)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||
model_type = config.model_type
|
||||
|
||||
if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"):
|
||||
if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"):
|
||||
if sharded:
|
||||
if not FLASH_ATTENTION:
|
||||
raise NotImplementedError(
|
||||
|
Loading…
Reference in New Issue
Block a user