mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add lm_head
This commit is contained in:
parent
f652788d54
commit
246e8f8250
@ -165,9 +165,21 @@ class FlashSantacoder(FlashCausalLM):
|
|||||||
|
|
||||||
del value
|
del value
|
||||||
|
|
||||||
|
if model.lm_head.weight.device == torch.device("meta"):
|
||||||
|
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
model.post_load_weights(quantize)
|
model.post_load_weights(quantize)
|
||||||
|
|
||||||
|
uninitialized_parameters = []
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if p.data.device == torch.device("meta"):
|
||||||
|
uninitialized_parameters.append(n)
|
||||||
|
if uninitialized_parameters:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"found uninitialized parameters in model : {uninitialized_parameters}"
|
||||||
|
)
|
||||||
|
|
||||||
def decode(self, generated_ids: List[int]) -> str:
|
def decode(self, generated_ids: List[int]) -> str:
|
||||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||||
return self.tokenizer.decode(
|
return self.tokenizer.decode(
|
||||||
@ -387,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||||||
else:
|
else:
|
||||||
module._buffers[param_name] = tensor
|
module._buffers[param_name] = tensor
|
||||||
|
|
||||||
|
if model.lm_head.weight.device == torch.device("meta"):
|
||||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
model.post_load_weights(quantize)
|
model.post_load_weights(quantize)
|
||||||
|
Loading…
Reference in New Issue
Block a user