mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Loading config *after* checking for model name.
This commit is contained in:
parent
bef3458ee8
commit
cd8477bcf8
@@ -99,10 +99,21 @@ def get_model(
|
|||||||
else:
|
else:
|
||||||
return Galactica(model_id, revision, quantize=quantize)
|
return Galactica(model_id, revision, quantize=quantize)
|
||||||
|
|
||||||
|
if model_id.startswith("bigcode/"):
|
||||||
|
if sharded:
|
||||||
|
if not FLASH_ATTENTION:
|
||||||
|
raise NotImplementedError(
|
||||||
|
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
||||||
|
)
|
||||||
|
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
|
||||||
|
else:
|
||||||
|
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
||||||
|
return santacoder_cls(model_id, revision, quantize=quantize)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||||
model_type = config.model_type
|
model_type = config.model_type
|
||||||
|
|
||||||
if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"):
|
if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"):
|
||||||
if sharded:
|
if sharded:
|
||||||
if not FLASH_ATTENTION:
|
if not FLASH_ATTENTION:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
Loading…
Reference in New Issue
Block a user