mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
Loading config *after* checking for model name.
This commit is contained in:
parent
bef3458ee8
commit
cd8477bcf8
@ -99,6 +99,17 @@ def get_model(
|
|||||||
else:
|
else:
|
||||||
return Galactica(model_id, revision, quantize=quantize)
|
return Galactica(model_id, revision, quantize=quantize)
|
||||||
|
|
||||||
|
if model_id.startswith("bigcode/"):
|
||||||
|
if sharded:
|
||||||
|
if not FLASH_ATTENTION:
|
||||||
|
raise NotImplementedError(
|
||||||
|
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
||||||
|
)
|
||||||
|
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
|
||||||
|
else:
|
||||||
|
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
||||||
|
return santacoder_cls(model_id, revision, quantize=quantize)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||||
model_type = config.model_type
|
model_type = config.model_type
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user