diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e02be3de6..ec990fdea 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -99,7 +99,7 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)
 
-    if "bigcode" in model_id:
+    if model_id.startswith("bigcode/"):
         if sharded:
             if not FLASH_ATTENTION:
                 raise NotImplementedError(
@@ -113,6 +113,17 @@ def get_model(
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
 
+    if model_type == "gpt_bigcode":
+        if sharded:
+            if not FLASH_ATTENTION:
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+        else:
+            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
+            return santacoder_cls(model_id, revision, quantize=quantize)
+
     if model_type == "bloom":
         if sharded:
             return BLOOMSharded(model_id, revision, quantize=quantize)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index c463ee98f..afe4eba5e 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -376,6 +376,9 @@ class FlashSantacoderSharded(FlashSantacoder):
                 else:
                     module._buffers[param_name] = tensor
 
+
+        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+
         uninitialized_parameters = []
         for n, p in model.named_parameters():
             if p.data.device == torch.device("meta"):
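
A note on the dispatch change in `__init__.py`: routing now keys off `config.model_type` from `AutoConfig` rather than substring-matching the repo name, so any checkpoint whose config declares `gpt_bigcode` takes the new branch, not just models hosted under `bigcode/`. A minimal sketch of the detection, using `bigcode/gpt_bigcode-santacoder` as an illustrative model id:

```python
from transformers import AutoConfig

# Illustrative example: any repo whose config.json declares
# "model_type": "gpt_bigcode" would be routed to the Santacoder classes.
config = AutoConfig.from_pretrained("bigcode/gpt_bigcode-santacoder")
print(config.model_type)  # expected: "gpt_bigcode"
```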
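On the `flash_santacoder.py` hunk: gpt_bigcode-style checkpoints typically ship only the input embedding matrix `wte` and no separate `lm_head` tensor, so after the sharded load the output projection would otherwise remain uninitialized on the meta device; the added line ties it back to the embeddings. A standalone sketch of the tying semantics, with toy sizes rather than the model's real dimensions:

```python
import torch

vocab_size, hidden_size = 8, 4  # toy sizes, not the model's real dimensions

wte = torch.nn.Embedding(vocab_size, hidden_size)               # input embeddings
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)  # output projection

# Wrapping the existing tensor in nn.Parameter shares storage instead of
# copying it, so the two modules now read and write the same weights.
lm_head.weight = torch.nn.Parameter(wte.weight)

# Any update to wte is immediately visible through lm_head (and vice versa).
with torch.no_grad():
    wte.weight[0].fill_(1.0)

assert torch.equal(lm_head.weight[0], wte.weight[0])
print("storage shared:", lm_head.weight.data_ptr() == wte.weight.data_ptr())
```

This mirrors the single line the patch adds, which points `model.lm_head.weight` at `model.transformer.wte.weight` once all shards are loaded.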