Hotfixes for santacoder/bigcode.

This commit is contained in:
Ubuntu 2023-05-08 09:45:27 +00:00
parent b4aa87db58
commit 7e11c5d92b
2 changed files with 7 additions and 4 deletions

View File

@@ -99,7 +99,10 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)

-    if "bigcode" in model_id:
+    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    model_type = config.model_type
+
+    if model_type == "gpt_bigcode":
         if sharded:
             if not FLASH_ATTENTION:
                 raise NotImplementedError(
@@ -110,9 +113,6 @@ def get_model(
         santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
         return santacoder_cls(model_id, revision, quantize=quantize)

-    config = AutoConfig.from_pretrained(model_id, revision=revision)
-    model_type = config.model_type
-
     if model_type == "bloom":
         if sharded:
             return BLOOMSharded(model_id, revision, quantize=quantize)

View File

@@ -373,6 +373,9 @@ class FlashSantacoderSharded(FlashSantacoder):
                 else:
                     module._buffers[param_name] = tensor

+        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+
         uninitialized_parameters = []
         for n, p in model.named_parameters():
             if p.data.device == torch.device("meta"):