mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Lifting the call to.
This commit is contained in:
parent
cc3cdeb156
commit
62b4082514
@ -238,8 +238,6 @@ class BLOOMSharded(BLOOM):
|
||||
if name == "word_embeddings.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
|
@ -139,7 +139,6 @@ class FlashLlama(FlashCausalLM):
|
||||
|
||||
del value
|
||||
|
||||
model.check_initialized()
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
@ -307,7 +306,5 @@ class FlashLlamaSharded(FlashLlama):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.check_initialized()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
@ -152,5 +152,4 @@ class FlashNeoXSharded(FlashNeoX):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.check_initialized()
|
||||
model.post_load_weights(quantize)
|
||||
|
@ -377,6 +377,5 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
@ -365,8 +365,6 @@ class GalacticaSharded(Galactica):
|
||||
if name == "model.decoder.embed_tokens.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
|
@ -215,8 +215,6 @@ class GPTNeoxSharded(CausalLM):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
|
@ -32,6 +32,7 @@ class Model(ABC):
|
||||
self.decode_buffer = decode_buffer
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.check_initialized()
|
||||
|
||||
@property
|
||||
def info(self) -> InfoResponse:
|
||||
@ -107,5 +108,5 @@ class Model(ABC):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||
f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
|
||||
)
|
||||
|
@ -212,8 +212,6 @@ class OPTSharded(OPT):
|
||||
if name == "model.decoder.embed_tokens.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
|
@ -222,8 +222,6 @@ class T5Sharded(Seq2SeqLM):
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.check_initialized()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
|
Loading…
Reference in New Issue
Block a user