diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index c05e9a1d..3569e77c 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -203,7 +203,7 @@ class OPTSharded(OPT): tensor = tensor.to(device) module._parameters[param_name] = tensor - if name == "decoder.embed_tokens.weight": + if name == "model.decoder.embed_tokens.weight": model.lm_head._parameters["weight"] = tensor def forward(