diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 28d332a2..0dfd8078 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -83,9 +83,7 @@ def get_model(
     if "bigcode" in model_id:
         if sharded:
             if not FLASH_ATTENTION:
-                raise NotImplementedError(
-                    "sharded is not supported for Santacoder when FLASH_ATTENTION=0"
-                )
+                raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder"))
             return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 228529cc..508b7746 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -373,7 +373,7 @@ class LlamaMLP(nn.Module):
                 x,
                 approximate="tanh"
                 if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
+                else "none",
             )
         )

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 4ff17619..16fd4091 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -376,7 +376,12 @@ class FlashMLP(nn.Module):
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
         )

         if process_group is None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index b0b30db6..aba538ea 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -209,7 +209,7 @@ class FlashMQAttention(torch.nn.Module):
             self.num_heads = self.num_heads // process_group.size()
             self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
             self.c_proj = TensorParallelRowLinear(
-                hidden_size, hidden_size, process_group=process_group, reduce=True
+                hidden_size, hidden_size, process_group=process_group,
             )

     def forward(
@@ -317,7 +317,6 @@ class MLP(nn.Module):
                 intermediate_size,
                 hidden_size,
                 process_group=process_group,
-                reduce=False,
             )

     def forward(self, hidden_states):
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 6eff8551..fd9de934 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -64,7 +64,7 @@ class FlashSantacoder(FlashCausalLM):
             dtype,
             config.architectures[0].startswith("GPT2")
         )
-        self.model = model.eval().to(device).to(dtype)
+        self.model = model.eval()

         super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
@@ -176,38 +176,37 @@ class FlashSantacoderSharded(FlashSantacoder):
             device = torch.device(f"cuda:{self.rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         else:
-            raise NotImplementedError("FlashSantacoder is only available on GPU")
+            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

         if quantize:
-            raise NotImplementedError("FlashSantacoder does not support quantization")
+            raise NotImplementedError("FlashSantacoderSharded does not support quantization")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )

         config = GPT2Config.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=True,  # Needed as the config is not part of Transformers
         )

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            # model = FlashSantacoderForCausalLM(config, self.process_group)
-            model = FlashSantacoderForCausalLM(config)
+            model = FlashSantacoderForCausalLM(config, self.process_group)

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
             model,
             filenames,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -219,67 +218,68 @@ class FlashSantacoderSharded(FlashSantacoder):
         model,
         filenames: List[str],
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
         transpose: bool,
     ):
         for file in filenames:
             with safe_open(file, framework="pt", device=str(device)) as f:
-                for name in f.keys():
-                    slice_ = f.get_slice(name)
+                for key in f.keys():
+                    slice_ = f.get_slice(key)

-                    layer_name = ".".join(name.split(".")[:4])
+                    layer_name = ".".join(key.split(".")[:4])

                     # Fused qkv
-                    if "q_attn.weight" in name or "kv_attn.weight" in name:
-                        final_name = layer_name + ".c_attn.weight"
-                    elif "q_attn.bias" in name or "kv_attn.bias" in name:
-                        final_name = layer_name + ".c_attn.bias"
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
                     else:
-                        final_name = name
+                        final_key = key

-                    module_name, param_name = final_name.rsplit(".", 1)
+                    module_name, param_name = final_key.rsplit(".", 1)
                     module = model.get_submodule(module_name)

-                    # if isinstance(module, TensorParallelColumnLinear):
-                    #     dim = 1 if transpose and "weight" in param_name else 0
-                    #     size = slice_.get_shape()[dim]
-                    #     block_size = size // world_size
-                    #     start = rank * block_size
-                    #     stop = (rank + 1) * block_size
-                    #     tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
-                    # elif isinstance(module, TensorParallelRowLinear):
-                    #     if param_name == "weight":
-                    #         dim = 0 if transpose else 1
-                    #         size = slice_.get_shape()[dim]
-                    #         block_size = size // world_size
-                    #         start = rank * block_size
-                    #         stop = (rank + 1) * block_size
-                    #         tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
-                    #     else:
-                    #         tensor = slice_[:]
-                    #         # XXX: Hack for Rowlinear to add the bias only once.
-                    #         if rank != 0:
-                    #             tensor = torch.zeros_like(tensor)
-                    # elif isinstance(module, TensorParallelEmbedding):
-                    #     size = slice_.get_shape()[0]
-                    #     block_size = size // world_size
-                    #     start = rank * block_size
-                    #     stop = (rank + 1) * block_size
-                    #     tensor = slice_[start:stop]
-                    # elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-                    #     size = slice_.get_shape()[0]
-                    #     block_size = size // world_size
-                    #     start = rank * block_size
-                    #     stop = (rank + 1) * block_size
-                    #     tensor = slice_[start:stop]
-                    # else:
-                    try:
-                        tensor = slice_[:]
-                    except:
-                        tensor = f.get_tensor(name)
+                    if isinstance(module, TensorParallelColumnLinear):
+                        dim = 1 if transpose and "weight" in param_name else 0
+                        size = slice_.get_shape()[dim]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            dim = 0 if transpose else 1
+                            size = slice_.get_shape()[dim]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif key == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except:
+                            tensor = f.get_tensor(key)

-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)

                     try:
                         current_parameter_tensor = module._parameters[param_name]
@@ -288,18 +288,18 @@ class FlashSantacoderSharded(FlashSantacoder):

                     if current_parameter_tensor is not None:
                         if transpose and (
-                            "c_fc.weight" in name
-                            or "c_proj.weight" in name
-                            or "q_attn.weight" in name
-                            or "kv_attn.weight" in name
-                            or "c_attn.weight" in name
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
                         ):
                             # Tranpose as we use nn.Linear instead of Conv1D
                             tensor = tensor.T

                         if current_parameter_tensor.device == torch.device("meta"):
                             # Init qkv
-                            if "c_attn.weight" in final_name:
+                            if "c_attn.weight" in final_key:
                                 module._parameters[param_name] = tensor.new_empty(
                                     (
                                         model.transformer.head_size
@@ -307,7 +307,7 @@ class FlashSantacoderSharded(FlashSantacoder):
                                         tensor.shape[1],
                                     )
                                 )
-                            elif "c_attn.bias" in final_name:
+                            elif "c_attn.bias" in final_key:
                                 module._parameters[param_name] = tensor.new_empty(
                                     (
                                         model.transformer.head_size
@@ -316,63 +316,47 @@ class FlashSantacoderSharded(FlashSantacoder):
                                 )

                         # Copy to correct slice
-                        # if "q_attn" in name:
-                        #     size = tensor.shape[0]
-                        #     block_size = size // world_size
-                        #     start = rank * block_size
-                        #     stop = (rank + 1) * block_size
-                        #     tensor = tensor[start:stop]
-                        #     module._parameters[param_name][: tensor.shape[0]] = tensor
-                        # elif "kv_attn.weight" in name:
-                        #     module._parameters[param_name][
-                        #         model.transformer.head_size
-                        #         * model.transformer.num_heads :
-                        #     ] = tensor
-                        # elif "kv_attn.bias" in name:
-                        #     module._parameters[param_name][
-                        #         model.transformer.head_size
-                        #         * model.transformer.num_heads :
-                        #     ] = tensor
-                        # elif "c_attn" in name:
-                        #     q_tensor = tensor[: -2 * model.transformer.head_size]
-                        #     kv_tensor = tensor[-2 * model.transformer.head_size :]
-                        #     from loguru import logger
-                        #
-                        #     block_size = q_tensor.shape[0] // world_size
-                        #     start = rank * block_size
-                        #     stop = (rank + 1) * block_size
-                        #     q_tensor = q_tensor[start:stop]
-                        #     logger.error(q_tensor.shape)
-                        #     logger.error(kv_tensor.shape)
-                        #     module._parameters[param_name][
-                        #         : q_tensor.shape[0]
-                        #     ] = q_tensor
-                        #     module._parameters[param_name][
-                        #         q_tensor.shape[0] :
-                        #     ] = kv_tensor
-                        from loguru import logger
-                        if "q_attn.weight" in name:
-                            logger.error(f"q - {module._parameters[param_name][: tensor.shape[0]].shape} - {tensor.shape}")
+                        if "q_attn" in key:
+                            size = tensor.shape[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = tensor[start:stop]
                             module._parameters[param_name][: tensor.shape[0]] = tensor
-                        elif "q_attn.bias" in name:
-                            module._parameters[param_name][: tensor.shape[0]] = tensor
-                        elif "kv_attn.weight" in name:
-                            logger.error(f"kv - {module._parameters[param_name][model.transformer.head_size * model.transformer.num_heads:].shape} - {tensor.shape}")
+                        elif "kv_attn.weight" in key:
                             module._parameters[param_name][
-                                model.transformer.head_size * model.transformer.num_heads:
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
                             ] = tensor
-                        elif "kv_attn.bias" in name:
+                        elif "kv_attn.bias" in key:
                             module._parameters[param_name][
-                                model.transformer.head_size * model.transformer.num_heads:
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
                             ] = tensor
+                        elif "c_attn" in key:
+                            # Slice q_tensor by shard
+                            q_tensor = tensor[: -2 * model.transformer.head_size]
+                            block_size = q_tensor.shape[0] // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            q_tensor = q_tensor[start:stop]
+
+                            module._parameters[param_name][
+                                : q_tensor.shape[0]
+                            ] = q_tensor
+
+                            # Kv tensor is copied for every shard
+                            kv_tensor = tensor[-2 * model.transformer.head_size :]
+                            module._parameters[param_name][
+                                q_tensor.shape[0] :
+                            ] = kv_tensor
                         else:
                             if current_parameter_tensor.shape != tensor.shape:
                                 raise ValueError(
-                                    f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                 )
                             module._parameters[param_name] = tensor
-
                     else:
                         module._buffers[param_name] = tensor

         torch.cuda.empty_cache()
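
Note on the fused-QKV sharding in the `elif "c_attn" in key:` branch above: Santacoder uses multi-query attention, so the checkpoint stores a single `c_attn` matrix whose last `2 * head_size` rows are the shared key/value projection. The new `load_weights` splits only the query rows across ranks and copies the KV rows to every shard. The snippet below is a minimal, standalone sketch of that scheme (not part of the diff); `head_size`, `num_heads`, `hidden_size`, and `world_size` are made-up illustrative values, not taken from any config.

import torch

head_size, num_heads, hidden_size, world_size = 4, 8, 32, 2

# Fused q + kv projection weight as stored in the checkpoint:
# [(num_heads + 2) * head_size, hidden_size]
c_attn = torch.randn((num_heads + 2) * head_size, hidden_size)

for rank in range(world_size):
    # Query rows: one contiguous block per rank.
    q = c_attn[: -2 * head_size]
    block_size = q.shape[0] // world_size
    q_shard = q[rank * block_size : (rank + 1) * block_size]

    # Key/value rows: the single shared KV head is replicated on every rank.
    kv = c_attn[-2 * head_size :]

    shard = torch.cat([q_shard, kv], dim=0)
    # Each rank ends up with (num_heads // world_size + 2) * head_size rows.
    print(rank, shard.shape)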