From 2378529c1509ec14167d1eb1e28e5d44d37e48a8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 6 Apr 2023 15:03:13 +0200 Subject: [PATCH] wip --- .../models/flash_santacoder.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index b06b9d95..be94c58b 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -62,6 +62,7 @@ class FlashSantacoder(FlashCausalLM): filenames, device, dtype, + config.architectures[0].startswith("GPT2") ) self.model = model.eval().to(device).to(dtype) @@ -76,6 +77,7 @@ class FlashSantacoder(FlashCausalLM): filenames: List[Path], device: torch.device, dtype: torch.dtype, + transpose: bool ): for filename in filenames: state_dict = torch.load(filename, map_location="cpu") @@ -102,7 +104,7 @@ class FlashSantacoder(FlashCausalLM): current_parameter_tensor = None if current_parameter_tensor is not None: - if ( + if transpose and ( "c_fc.weight" in key or "c_proj.weight" in key or "q_attn.weight" in key @@ -202,6 +204,7 @@ class FlashSantacoderSharded(FlashSantacoder): device=device, rank=self.rank, world_size=self.world_size, + transpose=config.architectures[0].startswith("GPT2") ) self.model = model.eval().to(dtype) torch.distributed.barrier(group=self.process_group) @@ -217,6 +220,7 @@ class FlashSantacoderSharded(FlashSantacoder): device: torch.device, rank: int, world_size: int, + transpose: bool ): for file in filenames: with safe_open(file, framework="pt", device=str(device)) as f: @@ -262,7 +266,7 @@ class FlashSantacoderSharded(FlashSantacoder): tensor = slice_[start:stop] elif "c_attn" in name: size = slice_.get_shape()[0] - block_size = + raise ValueError elif name == "lm_head.weight" and model.transformer.tp_embeddings: size = slice_.get_shape()[0] block_size = size // world_size @@ -283,6 +287,16 @@ class FlashSantacoderSharded(FlashSantacoder): current_parameter_tensor = None if current_parameter_tensor is not None: + if transpose and ( + "c_fc.weight" in name + or "c_proj.weight" in name + or "q_attn.weight" in name + or "kv_attn.weight" in name + or "c_attn.weight" in name + ): + # Tranpose as we use nn.Linear instead of Conv1D + value = value.T + if current_parameter_tensor.device == torch.device("meta"): # Init qkv if "c_attn.weight" in final_name: