diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 0cd8d133..28d332a2 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -18,7 +18,6 @@ from text_generation_server.models.t5 import T5Sharded
 
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
     from text_generation_server.models.flash_santacoder import FlashSantacoder, FlashSantacoderSharded
 
@@ -84,7 +83,9 @@ def get_model(
     if "bigcode" in model_id:
         if sharded:
             if not FLASH_ATTENTION:
-                raise NotImplementedError("sharded is not supported for Santacoder when FLASH_ATTENTION=0")
+                raise NotImplementedError(
+                    "sharded is not supported for Santacoder when FLASH_ATTENTION=0"
+                )
             return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 576b86e7..b0b30db6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -69,13 +69,13 @@ class FastLinear(nn.Linear):
 
 class TensorParallelColumnLinear(FastLinear):
     def __init__(
-            self,
-            in_features,
-            out_features,
-            process_group: torch.distributed.ProcessGroup,
-            bias=True,
-            device=None,
-            dtype=None,
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -93,14 +93,14 @@ class TensorParallelColumnLinear(FastLinear):
 
 class TensorParallelRowLinear(FastLinear):
     def __init__(
-            self,
-            in_features,
-            out_features,
-            process_group: torch.distributed.ProcessGroup,
-            reduce=True,
-            bias=True,
-            device=None,
-            dtype=None,
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -126,19 +126,19 @@ class TensorParallelRowLinear(FastLinear):
 
 class TensorParallelEmbedding(nn.Embedding):
     def __init__(
-            self,
-            num_embeddings,
-            embedding_dim,
-            process_group: torch.distributed.ProcessGroup,
-            reduce=True,
-            padding_idx=None,
-            max_norm=None,
-            norm_type=2.0,
-            scale_grad_by_freq=False,
-            sparse=False,
-            _weight=None,
-            device=None,
-            dtype=None,
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
     ):
         self.process_group = process_group
         self.tp_rank = process_group.rank()
@@ -207,11 +207,7 @@ class FlashMQAttention(torch.nn.Module):
             self.c_proj = FastLinear(hidden_size, hidden_size)
         else:
             self.num_heads = self.num_heads // process_group.size()
-            self.hidden_size = self.hidden_size // process_group.size()
-            self.c_attn = FastLinear(
-                hidden_size,
-                self.head_size * (self.num_heads + 2)
-            )
+            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
             self.c_proj = TensorParallelRowLinear(
                 hidden_size, hidden_size, process_group=process_group, reduce=True
             )
@@ -228,7 +224,9 @@ class FlashMQAttention(torch.nn.Module):
         qkv = self.c_attn(hidden_states)
 
         # Split query from key_value
-        query, key_value = qkv.split([self.head_size * self.num_heads, 2 * self.head_size], dim=1)
+        query, key_value = qkv.split(
+            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+        )
 
         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
@@ -302,7 +300,7 @@ class MLP(nn.Module):
                 x,
                 approximate="tanh"
                 if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
+                else "none",
             )
         )
 
@@ -399,11 +397,13 @@ class FlashSantacoderModel(nn.Module):
             self.wte = TensorParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
+                reduce=False,
                 process_group=process_group,
             )
             self.wpe = TensorParallelEmbedding(
                 config.max_position_embeddings,
                 config.hidden_size,
+                reduce=False,
                 process_group=process_group,
            )
         else:
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index be94c58b..6eff8551 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -195,7 +195,8 @@ class FlashSantacoderSharded(FlashSantacoder):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
-            model = FlashSantacoderForCausalLM(config, self.process_group)
+            # model = FlashSantacoderForCausalLM(config, self.process_group)
+            model = FlashSantacoderForCausalLM(config)
 
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
@@ -204,7 +205,7 @@ class FlashSantacoderSharded(FlashSantacoder):
             device=device,
             rank=self.rank,
             world_size=self.world_size,
-            transpose=config.architectures[0].startswith("GPT2")
+            transpose=config.architectures[0].startswith("GPT2"),
         )
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
@@ -220,7 +221,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         device: torch.device,
         rank: int,
         world_size: int,
-        transpose: bool
+        transpose: bool,
     ):
         for file in filenames:
             with safe_open(file, framework="pt", device=str(device)) as f:
@@ -240,44 +241,43 @@ class FlashSantacoderSharded(FlashSantacoder):
                     module_name, param_name = final_name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
 
-                    if isinstance(module, TensorParallelColumnLinear):
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    elif isinstance(module, TensorParallelRowLinear):
-                        if param_name == "weight":
-                            size = slice_.get_shape()[1]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                        else:
-                            tensor = slice_[:]
-                            # XXX: Hack for Rowlinear to add the bias only once.
-                            if rank != 0:
-                                tensor = torch.zeros_like(tensor)
-                    elif isinstance(module, TensorParallelEmbedding):
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    elif "c_attn" in name:
-                        size = slice_.get_shape()[0]
-                        raise ValueError
-                    elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-                        size = slice_.get_shape()[0]
-                        block_size = size // world_size
-                        start = rank * block_size
-                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop]
-                    else:
-                        try:
-                            tensor = slice_[:]
-                        except:
-                            tensor = f.get_tensor(name)
+                    # if isinstance(module, TensorParallelColumnLinear):
+                    #     dim = 1 if transpose and "weight" in param_name else 0
+                    #     size = slice_.get_shape()[dim]
+                    #     block_size = size // world_size
+                    #     start = rank * block_size
+                    #     stop = (rank + 1) * block_size
+                    #     tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                    # elif isinstance(module, TensorParallelRowLinear):
+                    #     if param_name == "weight":
+                    #         dim = 0 if transpose else 1
+                    #         size = slice_.get_shape()[dim]
+                    #         block_size = size // world_size
+                    #         start = rank * block_size
+                    #         stop = (rank + 1) * block_size
+                    #         tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                    #     else:
+                    #         tensor = slice_[:]
+                    #         # XXX: Hack for Rowlinear to add the bias only once.
+                    #         if rank != 0:
+                    #             tensor = torch.zeros_like(tensor)
+                    # elif isinstance(module, TensorParallelEmbedding):
+                    #     size = slice_.get_shape()[0]
+                    #     block_size = size // world_size
+                    #     start = rank * block_size
+                    #     stop = (rank + 1) * block_size
+                    #     tensor = slice_[start:stop]
+                    # elif name == "lm_head.weight" and model.transformer.tp_embeddings:
+                    #     size = slice_.get_shape()[0]
+                    #     block_size = size // world_size
+                    #     start = rank * block_size
+                    #     stop = (rank + 1) * block_size
+                    #     tensor = slice_[start:stop]
+                    # else:
+                    try:
+                        tensor = slice_[:]
+                    except:
+                        tensor = f.get_tensor(name)
 
                     tensor = tensor.contiguous()
 
@@ -295,7 +295,7 @@ class FlashSantacoderSharded(FlashSantacoder):
                         or "c_attn.weight" in name
                     ):
                         # Tranpose as we use nn.Linear instead of Conv1D
-                        value = value.T
+                        tensor = tensor.T
 
                     if current_parameter_tensor.device == torch.device("meta"):
                         # Init qkv
@@ -316,19 +316,54 @@
                            )
 
                     # Copy to correct slice
+                    # if "q_attn" in name:
+                    #     size = tensor.shape[0]
+                    #     block_size = size // world_size
+                    #     start = rank * block_size
+                    #     stop = (rank + 1) * block_size
+                    #     tensor = tensor[start:stop]
+                    #     module._parameters[param_name][: tensor.shape[0]] = tensor
+                    # elif "kv_attn.weight" in name:
+                    #     module._parameters[param_name][
+                    #         model.transformer.head_size
+                    #         * model.transformer.num_heads :
+                    #     ] = tensor
+                    # elif "kv_attn.bias" in name:
+                    #     module._parameters[param_name][
+                    #         model.transformer.head_size
+                    #         * model.transformer.num_heads :
+                    #     ] = tensor
+                    # elif "c_attn" in name:
+                    #     q_tensor = tensor[: -2 * model.transformer.head_size]
+                    #     kv_tensor = tensor[-2 * model.transformer.head_size :]
+                    #     from loguru import logger
+                    #
+                    #     block_size = q_tensor.shape[0] // world_size
+                    #     start = rank * block_size
+                    #     stop = (rank + 1) * block_size
+                    #     q_tensor = q_tensor[start:stop]
+                    #     logger.error(q_tensor.shape)
+                    #     logger.error(kv_tensor.shape)
+                    #     module._parameters[param_name][
+                    #         : q_tensor.shape[0]
+                    #     ] = q_tensor
+                    #     module._parameters[param_name][
+                    #         q_tensor.shape[0] :
+                    #     ] = kv_tensor
+                    from loguru import logger
                     if "q_attn.weight" in name:
+                        logger.error(f"q - {module._parameters[param_name][: tensor.shape[0]].shape} - {tensor.shape}")
                         module._parameters[param_name][: tensor.shape[0]] = tensor
                     elif "q_attn.bias" in name:
                         module._parameters[param_name][: tensor.shape[0]] = tensor
                     elif "kv_attn.weight" in name:
+                        logger.error(f"kv - {module._parameters[param_name][model.transformer.head_size * model.transformer.num_heads:].shape} - {tensor.shape}")
                         module._parameters[param_name][
-                            model.transformer.head_size
-                            * model.transformer.num_heads :
+                            model.transformer.head_size * model.transformer.num_heads:
                         ] = tensor
                     elif "kv_attn.bias" in name:
                         module._parameters[param_name][
-                            model.transformer.head_size
-                            * model.transformer.num_heads :
+                            model.transformer.head_size * model.transformer.num_heads:
                         ] = tensor
                     else:
                         if current_parameter_tensor.shape != tensor.shape:
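
For reference, a minimal standalone sketch of the per-rank block slicing the sharded loader applies to tensor-parallel weights (the `size // world_size` pattern above); `shard_rows`, the toy tensor, and the even-divisibility assumption are illustrative only, not part of the patch.

import torch

def shard_rows(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Each rank keeps one contiguous block of rows (dim 0); assumes the
    # dimension divides evenly by world_size, as the loader above does.
    size = weight.shape[0]
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return weight[start:stop]

# Concatenating every rank's shard along dim 0 recovers the full weight.
full = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
shards = [shard_rows(full, rank, world_size=2) for rank in range(2)]
assert torch.equal(torch.cat(shards, dim=0), full)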