patch safetensors loading

This commit is contained in:
OlivierDehaene 2023-02-28 15:56:35 +01:00
parent ef51a1e0b7
commit 34931a2111

View File

@ -203,7 +203,7 @@ class OPTSharded(OPT):
tensor = tensor.to(device)
module._parameters[param_name] = tensor
if name == "decoder.embed_tokens.weight":
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(