From cd5d0a96bafe13a854c4cb605e4040327d5480bf Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 28 Mar 2023 16:12:05 +0200
Subject: [PATCH] feat(server): add flash attention llama

---
 .../text_generation_server/models/__init__.py |   8 +-
 .../custom_modeling/flash_llama_modeling.py   | 146 ++++++-------
 .../models/flash_llama.py                     | 202 ++++++++++--------
 3 files changed, 190 insertions(+), 166 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 577f94b8..bc802df9 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -19,7 +19,7 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
     from text_generation_server.models.flash_santacoder import FlashSantacoder
-    from text_generation_server.models.flash_llama import FlashLlama
+    from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded

     FLASH_ATTENTION = (
         torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
@@ -95,7 +95,11 @@ def get_model(

     if model_type == "llama":
         if sharded:
-            raise NotImplementedError
+            if FLASH_ATTENTION:
+                return FlashLlamaSharded(model_id, revision, quantize=quantize)
+            raise NotImplementedError(
+                "sharded is not supported for llama when FLASH_ATTENTION=0"
+            )
         else:
             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
             return llama_cls(model_id, revision, quantize=quantize)
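
Note: the hunk above gates the llama path on the FLASH_ATTENTION environment flag and picks the model class accordingly. A minimal sketch of that dispatch idea, using illustrative stand-in classes (EagerImpl/FlashImpl and build are not names from the repository):

import os

class EagerImpl:
    def __init__(self, model_id, revision=None, quantize=False):
        self.kind = "eager"

class FlashImpl(EagerImpl):
    def __init__(self, model_id, revision=None, quantize=False):
        super().__init__(model_id, revision, quantize)
        self.kind = "flash"

# Resolve the flag once at import time, as the server does, then choose the
# implementation class at call time.
FLASH_ATTENTION = int(os.environ.get("FLASH_ATTENTION", 0)) == 1

def build(model_id, sharded=False):
    if sharded and not FLASH_ATTENTION:
        raise NotImplementedError("sharded llama requires FLASH_ATTENTION=1")
    cls = FlashImpl if FLASH_ATTENTION else EagerImpl
    return cls(model_id)
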
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 38d9fa24..5e5610ff 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -1,3 +1,23 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import torch.distributed

@@ -23,15 +43,45 @@ class LlamaRMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 6144:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states

-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(
+                variance + self.variance_epsilon
+            )

-        return self.weight * hidden_states
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+
+            return self.weight * hidden_states, residual
+        else:
+            # faster post attention rms norm
+            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                None,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.variance_epsilon,
+                1.0,
+                0,
+                None,
+                False,
+                True,  # Activate RMSNorm
+            )
+            if res is None:
+                res = hidden_states
+
+            return normed_hidden_states, res


 class FastLinear(nn.Linear):
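
Note: the rewritten LlamaRMSNorm.forward returns both the normalized activations and the running residual, so each layer can fold `residual + hidden_states` into the norm (either in the slow Python path above or in the fused dropout_layer_norm kernel). A standalone sketch of the slow path's math, with no dropout_layer_norm dependency (rms_norm_with_residual is an illustrative helper, not code from this patch):

import torch

def rms_norm_with_residual(hidden_states, weight, residual=None, eps=1e-6):
    # Fold the previous residual in before normalizing, as the slow path does.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states

    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)

    # Cast back to the weight dtype before scaling, mirroring the module code.
    if weight.dtype in (torch.float16, torch.bfloat16):
        hidden_states = hidden_states.to(weight.dtype)
    return weight * hidden_states, residual

# Example:
# x = torch.randn(4, 4096, dtype=torch.float16)
# y, res = rms_norm_with_residual(x, torch.ones(4096, dtype=torch.float16))
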
@@ -183,8 +233,6 @@ class PositionRotaryEmbedding(RotaryEmbedding):
         ):
             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = torch.cos(freqs).to(dtype)
             self._sin_cached = torch.sin(freqs).to(dtype)
@@ -245,16 +293,6 @@ class FlashLlamaAttention(torch.nn.Module):
             process_group=process_group,
         )

-    def shuffle_qkv_dims(self):
-        """Swap dims to avoid an additional permute"""
-        self.query_key_value.weight = torch.nn.Parameter(
-            self.query_key_value.weight.view(
-                self.num_heads, 3, self.head_size, self.hidden_size
-            )
-            .permute(1, 0, 2, 3)
-            .reshape(-1, self.hidden_size)
-        )
-
     def forward(
         self,
         hidden_states,
@@ -333,7 +371,6 @@ class LlamaMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
         )
-        self.intermediate_size = intermediate_size

         if process_group is None:
             # Fuse gate and up proj
@@ -341,6 +378,7 @@ class LlamaMLP(nn.Module):
                 hidden_size, 2 * intermediate_size, bias=False
             )
             self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
+            self.intermediate_size = intermediate_size
         else:
             # Fuse gate and up proj
             self.gate_up_proj = TensorParallelColumnLinear(
@@ -356,6 +394,8 @@ class LlamaMLP(nn.Module):
                 process_group=process_group,
                 reduce=True,
             )
+            self.intermediate_size = self.down_proj.in_features
+
         self.process_group = process_group

     def forward(self, hidden_states):
@@ -394,27 +434,11 @@ class FlashLlamaLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        # faster input rms norm
-        hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.input_layernorm.weight,
-            None,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.input_layernorm.variance_epsilon,
-            1.0,
-            0,
-            None,
-            False,
-            True,
-        )
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

-        hidden_states = self.self_attn(
-            hidden_states,
+        # Self Attention
+        attn_output = self.self_attn(
+            normed_hidden_states,
             cos,
             sin,
             cu_seqlens,
@@ -425,27 +449,13 @@ class FlashLlamaLayer(nn.Module):
             layer_past_present_indices,
             cu_seqlens_q,
         )

         # faster post attention rms norm
-        hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.post_attention_layernorm.weight,
-            None,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.post_attention_layernorm.variance_epsilon,
-            1.0,
-            0,
-            None,
-            False,
-            True,
+        normed_attn_res_output, attn_res = self.post_attention_layernorm(
+            attn_output, res
         )

-        mlp_output = self.mlp(hidden_states)
+        mlp_output = self.mlp(normed_attn_res_output)

-        return mlp_output, residual
+        return mlp_output, attn_res


 class FlashLlamaModel(torch.nn.Module):
@@ -492,7 +502,6 @@ class FlashLlamaModel(torch.nn.Module):
         self.embed_tokens.add_null_idx()
         for layer in self.layers:
             layer: FlashLlamaLayer
-            layer.self_attn.shuffle_qkv_dims()
             layer.self_attn.query_key_value.transpose_weight()
             layer.self_attn.o_proj.transpose_weight()
             layer.mlp.gate_up_proj.transpose_weight()
@@ -550,24 +559,7 @@ class FlashLlamaModel(torch.nn.Module):
                 cu_seqlens_q,
             )

-        # Faster final layer norm
-        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.norm.weight,
-            None,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.norm.variance_epsilon,
-            1.0,
-            0,
-            None,
-            False,
-            True,
-        )
+        hidden_states, _ = self.norm(hidden_states, residual)

         return hidden_states, past_key_values
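
Note: LlamaMLP now keeps a single gate_up_proj of width 2 * intermediate_size, so one matmul produces both the gate and the up projection. The MLP forward is not part of these hunks; a sketch of how such a fused projection is typically consumed, assuming the SiLU gating used by LLaMA configs (fused_mlp is an illustrative helper, not code from this patch):

import torch
import torch.nn.functional as F

def fused_mlp(hidden_states, gate_up_weight, down_weight):
    # gate_up_weight has shape (2 * intermediate_size, hidden_size): the first
    # half of the rows produces the gate, the second half the up projection.
    gate_up = F.linear(hidden_states, gate_up_weight)
    gate, up = gate_up.chunk(2, dim=-1)
    return F.linear(F.silu(gate) * up, down_weight)

# Shapes for LLaMA-7B (hidden_size=4096, intermediate_size=11008):
# gate_up_weight: (22016, 4096), down_weight: (4096, 11008)
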
".query_key_value.weight" - if final_key not in final_state_dict: - final_state_dict[final_key] = value.new_empty( - (value.shape[0] * 3, value.shape[1]) - ) - final_state_dict[final_key][value.shape[0] * 2 :] = value - elif "gate_proj" in key: + + # Fused gate and up projs + elif "gate_proj" in key or "up_proj" in key: final_key = layer_name + ".gate_up_proj.weight" - if final_key not in final_state_dict: - final_state_dict[final_key] = value.new_empty( - (value.shape[0] * 2, value.shape[1]) - ) - final_state_dict[final_key][: value.shape[0]] = value - elif "up_proj" in key: - final_key = layer_name + ".gate_up_proj.weight" - if final_key not in final_state_dict: - final_state_dict[final_key] = value.new_empty( - (value.shape[0] * 2, value.shape[1]) - ) - final_state_dict[final_key][value.shape[0] :] = value else: - final_state_dict[key] = value - del state_dict + final_key = key - parameters = dict(model.named_parameters()) - for key, value in final_state_dict.items(): - current_parameter_tensor = parameters.get(key, None) - module_name, param_name = key.rsplit(".", 1) - module = model.get_submodule(module_name) + module_name, param_name = final_key.rsplit(".", 1) + module = model.get_submodule(module_name) - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != value.shape - ): - raise ValueError( - f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) + try: + current_parameter_tensor = module._parameters[param_name] + except KeyError: + current_parameter_tensor = None - value = value.contiguous() + if current_parameter_tensor is not None: + if current_parameter_tensor.device == torch.device("meta"): + # Init qkv + if "query_key_value" in final_key: + module._parameters[param_name] = value.new_empty( + (value.shape[0] * 3, value.shape[1]) + ) + # Init gate and up proj + elif "gate_up_proj" in final_key: + module._parameters[param_name] = value.new_empty( + (value.shape[0] * 2, value.shape[1]) + ) - if current_parameter_tensor is not None: - module._parameters[param_name] = value - else: - module._buffers[param_name] = value + # Copy to correct slice + if "q_proj" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "k_proj" in key: + module._parameters[param_name][ + value.shape[0] : value.shape[0] * 2 + ] = value + elif "v_proj" in key: + module._parameters[param_name][value.shape[0] * 2 :] = value + elif "gate_proj" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "up_proj" in key: + module._parameters[param_name][value.shape[0] :] = value + else: + if current_parameter_tensor.shape != value.shape: + raise ValueError( + f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" + ) + module._parameters[param_name] = value + else: + module._buffers[param_name] = value + torch.cuda.empty_cache() model.post_load_weights() @@ -168,7 +161,7 @@ class FlashLlamaSharded(FlashLlama): filenames = weight_files(model_id, revision=revision, extension=".safetensors") with init_empty_weights(): - model = FlashGPTNeoXForCausalLM(config) + model = FlashLlamaForCausalLM(config, process_group=self.process_group) torch.distributed.barrier(group=self.process_group) self.load_weights( @@ -179,7 +172,6 @@ class FlashLlamaSharded(FlashLlama): rank=self.rank, world_size=self.world_size, ) - model.post_load_weights() self.model = model.eval().to(dtype) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( @@ -196,19 +188,28 @@ class 
@@ -196,19 +188,28 @@ class FlashLlamaSharded(FlashLlama):
         rank: int,
         world_size: int,
     ):
-        parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
-                    module_name, param_name = name.rsplit(".", 1)
-                    module = model.get_submodule(module_name)
-
-                    current_parameter_tensor = parameters.get(name, None)
-
                     slice_ = f.get_slice(name)

+                    layer_name = ".".join(name.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_proj" in name or "k_proj" in name or "v_proj" in name:
+                        final_name = layer_name + ".query_key_value.weight"
+
+                    # Fused gate and up projs
+                    elif "gate_proj" in name or "up_proj" in name:
+                        final_name = layer_name + ".gate_up_proj.weight"
+                    else:
+                        final_name = name
+
+                    module_name, param_name = final_name.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
                     if isinstance(module, TensorParallelColumnLinear):
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
@@ -216,24 +217,18 @@ class FlashLlamaSharded(FlashLlama):
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
                     elif isinstance(module, TensorParallelRowLinear):
-                        if param_name == "weight":
-                            size = slice_.get_shape()[1]
-                            block_size = size // world_size
-                            start = rank * block_size
-                            stop = (rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                        else:
-                            tensor = slice_[:]
-                            # XXX: Hack for Rowlinear to add the bias only once.
-                            if rank != 0:
-                                tensor = torch.zeros_like(tensor)
+                        size = slice_.get_shape()[1]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[:, start:stop]
                     elif isinstance(module, TensorParallelEmbedding):
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
+                    elif name == "lm_head.weight" and model.model.tp_embeddings:
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
@@ -245,20 +240,53 @@ class FlashLlamaSharded(FlashLlama):
                         except:
                             tensor = f.get_tensor(name)

-                    if (
-                        current_parameter_tensor is not None
-                        and current_parameter_tensor.shape != tensor.shape
-                    ):
-                        raise ValueError(
-                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-                        )
-
                     tensor = tensor.contiguous()

+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
                     if current_parameter_tensor is not None:
-                        module._parameters[param_name] = tensor
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "query_key_value" in final_name:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (tensor.shape[0] * 3, tensor.shape[1])
+                                )
+                            # Init gate and up proj
+                            elif "gate_up_proj" in final_name:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (tensor.shape[0] * 2, tensor.shape[1])
+                                )
+
+                        # Copy to correct slice
+                        if "q_proj" in name:
+                            module._parameters[param_name][: tensor.shape[0]] = tensor
+                        elif "k_proj" in name:
+                            module._parameters[param_name][
+                                tensor.shape[0] : tensor.shape[0] * 2
+                            ] = tensor
+                        elif "v_proj" in name:
+                            module._parameters[param_name][
+                                tensor.shape[0] * 2 :
+                            ] = tensor
+                        elif "gate_proj" in name:
+                            module._parameters[param_name][: tensor.shape[0]] = tensor
+                        elif "up_proj" in name:
+                            module._parameters[param_name][tensor.shape[0] :] = tensor
+                        else:
+                            if current_parameter_tensor.shape != tensor.shape:
+                                raise ValueError(
+                                    f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                                )
+
+                            module._parameters[param_name] = tensor
+
                     else:
                         module._buffers[param_name] = tensor

+        torch.cuda.empty_cache()
+        model.post_load_weights()

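Note: the sharded loader above cuts each weight along the dimension its layer type parallelizes: column-parallel linears, embeddings, and the lm_head are split along dim 0, row-parallel linears along dim 1, with every rank taking one contiguous block. A sketch of that rank-block arithmetic on a plain tensor; safetensors slices are indexed the same way (shard is an illustrative helper and assumes the size divides evenly, as the loader does):

import torch

def shard(tensor, rank, world_size, dim):
    # One contiguous block per rank along the parallelized dimension.
    size = tensor.shape[dim]
    block = size // world_size
    start, stop = rank * block, (rank + 1) * block
    if dim == 0:
        return tensor[start:stop].contiguous()
    return tensor[:, start:stop].contiguous()

w = torch.arange(32.0).reshape(8, 4)
col_shard = shard(w, rank=1, world_size=2, dim=0)  # column-parallel / embedding case
row_shard = shard(w, rank=1, world_size=2, dim=1)  # row-parallel case
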
@@ -268,7 +296,7 @@ class FlashLlamaSharded(FlashLlama):
     def forward(
         self,
         input_ids,
         position_ids,
         cu_seqlens,
         max_s: int,
         past_key_values: Optional = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.model.gpt_neox.tp_embeddings:
+        if self.model.model.tp_embeddings:
             logits, present = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
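
Note: the hunk above is cut off before the body of forward. When the lm_head is sharded (tp_embeddings), each rank only produces a slice of the vocabulary, and the usual way to recover full logits is an all_gather across ranks followed by a concat along the vocab dimension. A sketch of that gather pattern, not the file's exact continuation; it assumes an initialized torch.distributed process group, and gather_sharded_logits is an illustrative helper:

import torch
import torch.distributed as dist

def gather_sharded_logits(local_logits, world_size, process_group=None):
    # local_logits: (tokens, vocab // world_size) from this rank's lm_head slice;
    # concatenating the gathered shards restores the full vocab dimension.
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits, group=process_group)
    return torch.cat(shards, dim=-1)
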