diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 2a747f50..7907e2cc 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -165,9 +165,21 @@ class FlashSantacoder(FlashCausalLM): del value + if model.lm_head.weight.device == torch.device("meta"): + model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) + torch.cuda.empty_cache() model.post_load_weights(quantize) + uninitialized_parameters = [] + for n, p in model.named_parameters(): + if p.data.device == torch.device("meta"): + uninitialized_parameters.append(n) + if uninitialized_parameters: + raise RuntimeError( + f"found uninitialized parameters in model : {uninitialized_parameters}" + ) + def decode(self, generated_ids: List[int]) -> str: # Do not skip special tokens as they are used for custom parsing rules of the generated text return self.tokenizer.decode( @@ -387,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder): else: module._buffers[param_name] = tensor - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) + if model.lm_head.weight.device == torch.device("meta"): + model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) + torch.cuda.empty_cache() model.post_load_weights(quantize)