mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Lifting check_unitialized.
This commit is contained in:
parent
73d84c6ee5
commit
cc3cdeb156
@ -238,14 +238,7 @@ class BLOOMSharded(BLOOM):
|
|||||||
if name == "word_embeddings.weight":
|
if name == "word_embeddings.weight":
|
||||||
model.lm_head._parameters["weight"] = tensor
|
model.lm_head._parameters["weight"] = tensor
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
|
@ -139,15 +139,7 @@ class FlashLlama(FlashCausalLM):
|
|||||||
|
|
||||||
del value
|
del value
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
model.post_load_weights(quantize)
|
model.post_load_weights(quantize)
|
||||||
|
|
||||||
@ -315,14 +307,7 @@ class FlashLlamaSharded(FlashLlama):
|
|||||||
else:
|
else:
|
||||||
module._buffers[param_name] = tensor
|
module._buffers[param_name] = tensor
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
model.post_load_weights(quantize)
|
model.post_load_weights(quantize)
|
||||||
|
@ -152,13 +152,5 @@ class FlashNeoXSharded(FlashNeoX):
|
|||||||
else:
|
else:
|
||||||
module._buffers[param_name] = tensor
|
module._buffers[param_name] = tensor
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
model.post_load_weights(quantize)
|
model.post_load_weights(quantize)
|
||||||
|
@ -376,17 +376,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||||||
else:
|
else:
|
||||||
module._buffers[param_name] = tensor
|
module._buffers[param_name] = tensor
|
||||||
|
|
||||||
|
|
||||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||||
|
|
||||||
uninitialized_parameters = []
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
model.post_load_weights(quantize)
|
model.post_load_weights(quantize)
|
||||||
|
@ -365,14 +365,7 @@ class GalacticaSharded(Galactica):
|
|||||||
if name == "model.decoder.embed_tokens.weight":
|
if name == "model.decoder.embed_tokens.weight":
|
||||||
model.lm_head._parameters["weight"] = tensor
|
model.lm_head._parameters["weight"] = tensor
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
|
@ -215,14 +215,7 @@ class GPTNeoxSharded(CausalLM):
|
|||||||
else:
|
else:
|
||||||
module._buffers[param_name] = tensor
|
module._buffers[param_name] = tensor
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
|
@ -99,3 +99,13 @@ class Model(ABC):
|
|||||||
return token_text, None, None
|
return token_text, None, None
|
||||||
else:
|
else:
|
||||||
return "", offset, token_offset
|
return "", offset, token_offset
|
||||||
|
|
||||||
|
def check_initialized(self):
|
||||||
|
uninitialized_parameters = []
|
||||||
|
for n, p in self.named_parameters():
|
||||||
|
if p.data.device == torch.device("meta"):
|
||||||
|
uninitialized_parameters.append(n)
|
||||||
|
if uninitialized_parameters:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||||
|
)
|
||||||
|
@ -212,14 +212,7 @@ class OPTSharded(OPT):
|
|||||||
if name == "model.decoder.embed_tokens.weight":
|
if name == "model.decoder.embed_tokens.weight":
|
||||||
model.lm_head._parameters["weight"] = tensor
|
model.lm_head._parameters["weight"] = tensor
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
|
@ -222,14 +222,7 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
else:
|
else:
|
||||||
module._buffers[param_name] = tensor
|
module._buffers[param_name] = tensor
|
||||||
|
|
||||||
uninitialized_parameters = []
|
model.check_initialized()
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if p.data.device == torch.device("meta"):
|
|
||||||
uninitialized_parameters.append(n)
|
|
||||||
if uninitialized_parameters:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user