mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
patch safetensors loading
This commit is contained in:
parent
ef51a1e0b7
commit
34931a2111
@ -203,7 +203,7 @@ class OPTSharded(OPT):
|
|||||||
tensor = tensor.to(device)
|
tensor = tensor.to(device)
|
||||||
|
|
||||||
module._parameters[param_name] = tensor
|
module._parameters[param_name] = tensor
|
||||||
if name == "decoder.embed_tokens.weight":
|
if name == "model.decoder.embed_tokens.weight":
|
||||||
model.lm_head._parameters["weight"] = tensor
|
model.lm_head._parameters["weight"] = tensor
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
Loading…
Reference in New Issue
Block a user