[python] fix: Fix embedding mapping for deepspeed chat

This commit is contained in:
hyunwoongko 2023-06-03 12:00:07 +09:00
parent 895c5f1562
commit f6ba71f60f

View File

@ -255,7 +255,7 @@ class BLOOMSharded(BLOOM):
                raise ValueError(f"Unexpected quantize `{quantize}`")

            module._parameters[param_name] = tensor
-           if name == "word_embeddings.weight":
+           if "word_embeddings.weight" in name:
                model.lm_head._parameters["weight"] = tensor

    def forward(