mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add lm_head
This commit is contained in:
parent
f652788d54
commit
246e8f8250
@ -165,9 +165,21 @@ class FlashSantacoder(FlashCausalLM):
|
||||
|
||||
del value
|
||||
|
||||
if model.lm_head.weight.device == torch.device("meta"):
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model : {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
@ -387,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
if model.lm_head.weight.device == torch.device("meta"):
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
Loading…
Reference in New Issue
Block a user