Lifting check_initialized.

Nicolas Patry 2023-05-15 10:38:47 +02:00
parent 73d84c6ee5
commit cc3cdeb156
9 changed files with 18 additions and 76 deletions
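
Every file below gets the same mechanical change: the inlined scan for parameters still sitting on PyTorch's meta device is deleted and replaced by one call to a new Model.check_initialized() helper (added in model.py, below). For background, a parameter ends up on the meta device when its module is created without materializing storage, which is how the sharded loaders defer allocation until the real weights arrive. A minimal illustration (not from this commit):

    import torch

    # Creating a layer on the meta device allocates shape and dtype
    # metadata only; the weight has no backing storage until materialized.
    layer = torch.nn.Linear(4, 4, device="meta")
    print(layer.weight.device)   # meta
    print(layer.weight.is_meta)  # True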

View File

@@ -238,14 +238,7 @@ class BLOOMSharded(BLOOM):
                 if name == "word_embeddings.weight":
                     model.lm_head._parameters["weight"] = tensor

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@@ -139,15 +139,7 @@ class FlashLlama(FlashCausalLM):
                 del value

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()

         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
@@ -315,14 +307,7 @@ class FlashLlamaSharded(FlashLlama):
                 else:
                     module._buffers[param_name] = tensor

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()

         torch.cuda.empty_cache()
         model.post_load_weights(quantize)

View File

@@ -152,13 +152,5 @@ class FlashNeoXSharded(FlashNeoX):
                 else:
                     module._buffers[param_name] = tensor

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
-
+        model.check_initialized()
         model.post_load_weights(quantize)

View File

@@ -376,17 +376,7 @@ class FlashSantacoderSharded(FlashSantacoder):
                 else:
                     module._buffers[param_name] = tensor

         model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
-
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )

         torch.cuda.empty_cache()
         model.post_load_weights(quantize)

View File

@@ -365,14 +365,7 @@ class GalacticaSharded(Galactica):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@@ -215,14 +215,7 @@ class GPTNeoxSharded(CausalLM):
                 else:
                     module._buffers[param_name] = tensor

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@@ -99,3 +99,13 @@ class Model(ABC):
             return token_text, None, None
         else:
             return "", offset, token_offset
+
+    def check_initialized(self):
+        uninitialized_parameters = []
+        for n, p in self.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
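
For context only, a self-contained sketch (the ToyModel below is hypothetical, not from this repo) showing what the lifted helper catches: any parameter a loader never materialized is still on the meta device and gets reported by name.

    import torch


    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loaded = torch.nn.Linear(4, 4)  # weights materialized normally
            # Simulate a weight the loader never filled in: it stays on "meta".
            self.skipped = torch.nn.Linear(4, 4, device="meta")

        def check_initialized(self):
            # Same body as the Model.check_initialized helper added above.
            uninitialized_parameters = []
            for n, p in self.named_parameters():
                if p.data.device == torch.device("meta"):
                    uninitialized_parameters.append(n)
            if uninitialized_parameters:
                raise RuntimeError(
                    f"found uninitialized parameters in model: {uninitialized_parameters}"
                )


    ToyModel().check_initialized()
    # RuntimeError: found uninitialized parameters in model:
    # ['skipped.weight', 'skipped.bias']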

View File

@@ -212,14 +212,7 @@ class OPTSharded(OPT):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@@ -222,14 +222,7 @@ class T5Sharded(Seq2SeqLM):
                 else:
                     module._buffers[param_name] = tensor

-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()

     def forward(
         self,