Add lm_head: reuse transformer.wte.weight for lm_head.weight when it is still on the meta device

This commit is contained in:
OlivierDehaene 2023-06-01 11:46:51 +02:00
parent f652788d54
commit 246e8f8250

View File

@ -165,9 +165,21 @@ class FlashSantacoder(FlashCausalLM):
del value
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model : {uninitialized_parameters}"
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
@ -387,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder):
else:
module._buffers[param_name] = tensor
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)