mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Lifting check_unitialized.
This commit is contained in:
parent
73d84c6ee5
commit
cc3cdeb156
@ -238,14 +238,7 @@ class BLOOMSharded(BLOOM):
|
||||
if name == "word_embeddings.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
|
@ -139,15 +139,7 @@ class FlashLlama(FlashCausalLM):
|
||||
|
||||
del value
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
model.check_initialized()
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
@ -315,14 +307,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
model.check_initialized()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
@ -152,13 +152,5 @@ class FlashNeoXSharded(FlashNeoX):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
model.check_initialized()
|
||||
model.post_load_weights(quantize)
|
||||
|
@ -376,17 +376,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
@ -365,14 +365,7 @@ class GalacticaSharded(Galactica):
|
||||
if name == "model.decoder.embed_tokens.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
|
@ -215,14 +215,7 @@ class GPTNeoxSharded(CausalLM):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
|
@ -99,3 +99,13 @@ class Model(ABC):
|
||||
return token_text, None, None
|
||||
else:
|
||||
return "", offset, token_offset
|
||||
|
||||
def check_initialized(self):
|
||||
uninitialized_parameters = []
|
||||
for n, p in self.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
|
@ -212,14 +212,7 @@ class OPTSharded(OPT):
|
||||
if name == "model.decoder.embed_tokens.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
|
@ -222,14 +222,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
)
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user