Lifting the call to check_initialized() into the base Model class.

Nicolas Patry 2023-05-15 10:38:08 +02:00
parent cc3cdeb156
commit 62b4082514
9 changed files with 2 additions and 16 deletions
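In practice, this removes the model.check_initialized() call from the end of each model's load_weights and runs the check once in the base Model constructor instead, which is also why the error message now includes self.__class__.__name__. Below is a minimal sketch of the resulting pattern; the class hierarchy and constructor signature are simplified, and the meta-device check is an assumption based on how the loaders materialize weights, since the body of check_initialized is not shown in this diff.

from abc import ABC

import torch


class Model(ABC):
    # Simplified sketch of the base class after this commit, not the exact TGI code.
    def __init__(self, model: torch.nn.Module, rank: int = 0, world_size: int = 1):
        self.model = model
        self.rank = rank
        self.world_size = world_size
        # Lifted here: the check now runs once for every subclass, right after construction.
        self.check_initialized()

    def check_initialized(self):
        uninitialized_parameters = []
        for n, p in self.model.named_parameters():
            # Assumption: a parameter still on the "meta" device was never
            # materialized by load_weights.
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )


class BLOOMSharded(Model):
    # Subclasses no longer call model.check_initialized() at the end of
    # load_weights; forgetting to do so can no longer skip the check.
    @staticmethod
    def load_weights(model, filenames):
        ...  # load tensors into `model`; the base __init__ verifies nothing was missed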


@@ -238,8 +238,6 @@ class BLOOMSharded(BLOOM):
                     if name == "word_embeddings.weight":
                         model.lm_head._parameters["weight"] = tensor
 
-        model.check_initialized()
-
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):


@@ -139,7 +139,6 @@ class FlashLlama(FlashCausalLM):
                 del value
 
-        model.check_initialized()
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
@@ -307,7 +306,5 @@ class FlashLlamaSharded(FlashLlama):
                 else:
                     module._buffers[param_name] = tensor
 
-        model.check_initialized()
-
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)


@@ -152,5 +152,4 @@ class FlashNeoXSharded(FlashNeoX):
                 else:
                     module._buffers[param_name] = tensor
 
-        model.check_initialized()
         model.post_load_weights(quantize)


@@ -377,6 +377,5 @@ class FlashSantacoderSharded(FlashSantacoder):
                     module._buffers[param_name] = tensor
 
         model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
-        model.check_initialized()
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)


@@ -365,8 +365,6 @@ class GalacticaSharded(Galactica):
                     if name == "model.decoder.embed_tokens.weight":
                         model.lm_head._parameters["weight"] = tensor
 
-        model.check_initialized()
-
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):


@@ -215,8 +215,6 @@ class GPTNeoxSharded(CausalLM):
                 else:
                     module._buffers[param_name] = tensor
 
-        model.check_initialized()
-
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):


@@ -32,6 +32,7 @@ class Model(ABC):
         self.decode_buffer = decode_buffer
         self.rank = rank
         self.world_size = world_size
+        self.check_initialized()
 
     @property
     def info(self) -> InfoResponse:
@@ -107,5 +108,5 @@ class Model(ABC):
                 uninitialized_parameters.append(n)
         if uninitialized_parameters:
             raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
+                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
             )
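Since the check now lives in the shared base class, the class name in the message is what identifies which model failed to load. A hypothetical illustration of the failure mode it catches, reusing the simplified Model sketch above (FakeWeights and its layers are invented for the example):

import torch


class FakeWeights(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loaded = torch.nn.Linear(4, 4)
        # Deliberately left on the meta device, as if load_weights skipped it.
        self.forgotten = torch.nn.Linear(4, 4, device="meta")


# Constructing the wrapper now fails immediately with something like:
#   RuntimeError: found uninitialized parameters in model Model:
#   ['forgotten.weight', 'forgotten.bias']
Model(FakeWeights())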


@@ -212,8 +212,6 @@ class OPTSharded(OPT):
                     if name == "model.decoder.embed_tokens.weight":
                         model.lm_head._parameters["weight"] = tensor
 
-        model.check_initialized()
-
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):


@@ -222,8 +222,6 @@ class T5Sharded(Seq2SeqLM):
                 else:
                     module._buffers[param_name] = tensor
 
-        model.check_initialized()
-
     def forward(
         self,
         input_ids,