From cd8477bcf8bf2c7cbfb8f9463a3760a2495aa1dd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 15 May 2023 10:01:59 +0200
Subject: [PATCH] Loading config *after* checking for model name.

---
 server/text_generation_server/models/__init__.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 9613218d..026ab374 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -99,10 +99,21 @@ def get_model(
     else:
         return Galactica(model_id, revision, quantize=quantize)
 
+    if model_id.startswith("bigcode/"):
+        if sharded:
+            if not FLASH_ATTENTION:
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+        else:
+            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
+            return santacoder_cls(model_id, revision, quantize=quantize)
+
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
 
-    if model_type == "gpt_bigcode" or model_id.startswith("bigcode/"):
+    if model_type == "gpt_bigcode":
         if sharded:
             if not FLASH_ATTENTION:
                 raise NotImplementedError(
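
Side note for reviewers: a minimal standalone sketch of the dispatch flow this patch establishes. The stub classes and the name get_model_sketch are made up for illustration and are not the real text_generation_server classes; only the control flow, routing on the model name prefix before calling AutoConfig.from_pretrained, mirrors the real get_model().

from typing import Optional

from transformers import AutoConfig

FLASH_ATTENTION = True  # assumption: flash attention kernels are available


class SantaCoder: ...             # hypothetical stand-ins for the real
class FlashSantacoder: ...        # text_generation_server model classes
class FlashSantacoderSharded: ...


def get_model_sketch(model_id: str, revision: Optional[str], sharded: bool):
    # Route on the model name prefix first: no config download is needed
    # to decide that a "bigcode/" repo takes the Santacoder path.
    if model_id.startswith("bigcode/"):
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError("Sharded Santacoder requires flash attention")
            return FlashSantacoderSharded()
        santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
        return santacoder_cls()

    # Only when the name alone is not decisive do we load the config and
    # dispatch on config.model_type, as the rest of get_model() does.
    config = AutoConfig.from_pretrained(model_id, revision=revision)
    if config.model_type == "gpt_bigcode":
        return FlashSantacoder() if FLASH_ATTENTION else SantaCoder()
    raise NotImplementedError(f"no sketch handler for {model_id}")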