From 62b4082514c37b42b4396bb7af55808ebb4aa058 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 15 May 2023 10:38:08 +0200 Subject: [PATCH] Lifting the call to. --- server/text_generation_server/models/bloom.py | 2 -- server/text_generation_server/models/flash_llama.py | 3 --- server/text_generation_server/models/flash_neox.py | 1 - server/text_generation_server/models/flash_santacoder.py | 1 - server/text_generation_server/models/galactica.py | 2 -- server/text_generation_server/models/gpt_neox.py | 2 -- server/text_generation_server/models/model.py | 3 ++- server/text_generation_server/models/opt.py | 2 -- server/text_generation_server/models/t5.py | 2 -- 9 files changed, 2 insertions(+), 16 deletions(-) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 55e97613..f6a69031 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -238,8 +238,6 @@ class BLOOMSharded(BLOOM): if name == "word_embeddings.weight": model.lm_head._parameters["weight"] = tensor - model.check_initialized() - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index d3f48e2b..0b63f904 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -139,7 +139,6 @@ class FlashLlama(FlashCausalLM): del value - model.check_initialized() torch.cuda.empty_cache() model.post_load_weights(quantize) @@ -307,7 +306,5 @@ class FlashLlamaSharded(FlashLlama): else: module._buffers[param_name] = tensor - model.check_initialized() - torch.cuda.empty_cache() model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 6927d472..168c9195 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -152,5 +152,4 @@ class FlashNeoXSharded(FlashNeoX): else: module._buffers[param_name] = tensor - model.check_initialized() model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 884746a5..51a8998b 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -377,6 +377,5 @@ class FlashSantacoderSharded(FlashSantacoder): module._buffers[param_name] = tensor model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) - torch.cuda.empty_cache() model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 23972e89..c6dd4c33 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -365,8 +365,6 @@ class GalacticaSharded(Galactica): if name == "model.decoder.embed_tokens.weight": model.lm_head._parameters["weight"] = tensor - model.check_initialized() - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index b07d4c2a..215bb2b6 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -215,8 +215,6 @@ class GPTNeoxSharded(CausalLM): else: module._buffers[param_name] = tensor - model.check_initialized() - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index f9c64b2c..31bea6ca 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -32,6 +32,7 @@ class Model(ABC): self.decode_buffer = decode_buffer self.rank = rank self.world_size = world_size + self.check_initialized() @property def info(self) -> InfoResponse: @@ -107,5 +108,5 @@ class Model(ABC): uninitialized_parameters.append(n) if uninitialized_parameters: raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" + f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 91bf3715..8d856b10 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -212,8 +212,6 @@ class OPTSharded(OPT): if name == "model.decoder.embed_tokens.weight": model.lm_head._parameters["weight"] = tensor - model.check_initialized() - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 00c958c0..b5e7710d 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -222,8 +222,6 @@ class T5Sharded(Seq2SeqLM): else: module._buffers[param_name] = tensor - model.check_initialized() - def forward( self, input_ids,