Hotfixes for santacoder/bigcode.

This commit is contained in:
Ubuntu 2023-05-08 09:45:27 +00:00
parent b4aa87db58
commit 7e11c5d92b
2 changed files with 7 additions and 4 deletions

View File

@@ -99,7 +99,10 @@ def get_model(
else:
return Galactica(model_id, revision, quantize=quantize)
if "bigcode" in model_id:
config = AutoConfig.from_pretrained(model_id, revision=revision)
model_type = config.model_type
if model_type == "gpt_bigcode":
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(
@@ -110,9 +113,6 @@ def get_model(
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision)
model_type = config.model_type
if model_type == "bloom":
if sharded:
return BLOOMSharded(model_id, revision, quantize=quantize)

View File

@@ -373,6 +373,9 @@ class FlashSantacoderSharded(FlashSantacoder):
else:
module._buffers[param_name] = tensor
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):