From cc3cdeb1564ac43d9c6edebc282723129871658e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 15 May 2023 10:38:47 +0200
Subject: [PATCH] Lifting check_unitialized.

---
 server/text_generation_server/models/bloom.py |  9 +--------
 .../models/flash_llama.py                     | 19 ++-----------------
 .../models/flash_neox.py                      | 10 +---------
 .../models/flash_santacoder.py                | 10 ----------
 .../models/galactica.py                       |  9 +--------
 .../text_generation_server/models/gpt_neox.py |  9 +--------
 server/text_generation_server/models/model.py | 10 ++++++++++
 server/text_generation_server/models/opt.py   |  9 +--------
 server/text_generation_server/models/t5.py    |  9 +--------
 9 files changed, 18 insertions(+), 76 deletions(-)

diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 877acb00..55e97613 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -238,14 +238,7 @@ class BLOOMSharded(BLOOM):
                     if name == "word_embeddings.weight":
                         model.lm_head._parameters["weight"] = tensor
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 5f47cf66..d3f48e2b 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -139,15 +139,7 @@ class FlashLlama(FlashCausalLM):
 
                 del value
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
-
+        model.check_initialized()
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
 
@@ -315,14 +307,7 @@ class FlashLlamaSharded(FlashLlama):
                     else:
                         module._buffers[param_name] = tensor
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()
 
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index b3e1876f..6927d472 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -152,13 +152,5 @@ class FlashNeoXSharded(FlashNeoX):
                     else:
                         module._buffers[param_name] = tensor
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
-
+        model.check_initialized()
         model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index afe4eba5..884746a5 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -376,17 +376,7 @@ class FlashSantacoderSharded(FlashSantacoder):
                     else:
                         module._buffers[param_name] = tensor
 
-        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
-
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
 
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 4f94b348..23972e89 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -365,14 +365,7 @@ class GalacticaSharded(Galactica):
                     if name == "model.decoder.embed_tokens.weight":
                         model.lm_head._parameters["weight"] = tensor
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 2d42e0b0..b07d4c2a 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -215,14 +215,7 @@ class GPTNeoxSharded(CausalLM):
                    else:
                        module._buffers[param_name] = tensor
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 4c85c952..f9c64b2c 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -99,3 +99,13 @@ class Model(ABC):
             return token_text, None, None
         else:
             return "", offset, token_offset
+
+    def check_initialized(self):
+        uninitialized_parameters = []
+        for n, p in self.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 44f15df3..91bf3715 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -212,14 +212,7 @@ class OPTSharded(OPT):
                    if name == "model.decoder.embed_tokens.weight":
                        model.lm_head._parameters["weight"] = tensor
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 381617b7..00c958c0 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -222,14 +222,7 @@ class T5Sharded(Seq2SeqLM):
                    else:
                        module._buffers[param_name] = tensor
 
-        uninitialized_parameters = []
-        for n, p in model.named_parameters():
-            if p.data.device == torch.device("meta"):
-                uninitialized_parameters.append(n)
-        if uninitialized_parameters:
-            raise RuntimeError(
-                f"found uninitialized parameters in model: {uninitialized_parameters}"
-            )
+        model.check_initialized()
 
     def forward(
         self,
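
Editor's note, not part of the patch: the commit replaces nine duplicated meta-device scans with a single Model.check_initialized() helper (eight call sites now invoke it; flash_santacoder's copy is deleted outright). The standalone sketch below illustrates the behavior being lifted: a parameter still on PyTorch's meta device was allocated shape-only and never overwritten by a loaded weight, so the check names it and raises. The body of check_initialized is copied from the patch; TinyModel is a hypothetical stand-in for the server's Model base class, and the meta-device Linear assumes PyTorch 1.10+ factory kwargs.

import torch


class TinyModel:
    """Hypothetical stand-in for text_generation_server's Model ABC."""

    def __init__(self, model: torch.nn.Module):
        self.model = model

    def named_parameters(self):
        # Delegate to the wrapped torch module so check_initialized can
        # walk every (name, parameter) pair, as the lifted helper does.
        return self.model.named_parameters()

    def check_initialized(self):
        # Body copied from the patch: any parameter whose storage is
        # still on the meta device never received a real weight tensor.
        uninitialized_parameters = []
        for n, p in self.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model: {uninitialized_parameters}"
            )


# A layer allocated on the meta device is shape-only; loading real
# weights would place it on a concrete device. Until then, the check raises.
net = torch.nn.Linear(4, 4, device="meta")
try:
    TinyModel(net).check_initialized()
except RuntimeError as err:
    print(err)  # found uninitialized parameters in model: ['weight', 'bias']

Calling the helper once at the end of each loader, as the patch does, keeps the error message identical across models while removing the per-model copies of the loop.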