mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
feat(server): add flash attention llama
This commit is contained in:
parent
71402ed4c7
commit
cd5d0a96ba
@ -19,7 +19,7 @@ from text_generation_server.models.t5 import T5Sharded
|
||||
try:
|
||||
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
|
||||
from text_generation_server.models.flash_santacoder import FlashSantacoder
|
||||
from text_generation_server.models.flash_llama import FlashLlama
|
||||
from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
|
||||
|
||||
FLASH_ATTENTION = (
|
||||
torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
|
||||
@ -95,7 +95,11 @@ def get_model(
|
||||
|
||||
if model_type == "llama":
|
||||
if sharded:
|
||||
raise NotImplementedError
|
||||
if FLASH_ATTENTION:
|
||||
return FlashLlamaSharded(model_id, revision, quantize=quantize)
|
||||
raise NotImplementedError(
|
||||
"sharded is not supported for llama when FLASH_ATTENTION=0"
|
||||
)
|
||||
else:
|
||||
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
|
||||
return llama_cls(model_id, revision, quantize=quantize)
|
||||
|
@ -1,3 +1,23 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
@ -23,15 +43,45 @@ class LlamaRMSNorm(nn.Module):
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 6144:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
return self.weight * hidden_states
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
else:
|
||||
# faster post attention rms norm
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
class FastLinear(nn.Linear):
|
||||
@ -183,8 +233,6 @@ class PositionRotaryEmbedding(RotaryEmbedding):
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
@ -245,16 +293,6 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
process_group=process_group,
|
||||
)
|
||||
|
||||
def shuffle_qkv_dims(self):
|
||||
"""Swap dims to avoid an additional permute"""
|
||||
self.query_key_value.weight = torch.nn.Parameter(
|
||||
self.query_key_value.weight.view(
|
||||
self.num_heads, 3, self.head_size, self.hidden_size
|
||||
)
|
||||
.permute(1, 0, 2, 3)
|
||||
.reshape(-1, self.hidden_size)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -333,7 +371,6 @@ class LlamaMLP(nn.Module):
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
|
||||
)
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
if process_group is None:
|
||||
# Fuse gate and up proj
|
||||
@ -341,6 +378,7 @@ class LlamaMLP(nn.Module):
|
||||
hidden_size, 2 * intermediate_size, bias=False
|
||||
)
|
||||
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
|
||||
self.intermediate_size = intermediate_size
|
||||
else:
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear(
|
||||
@ -356,6 +394,8 @@ class LlamaMLP(nn.Module):
|
||||
process_group=process_group,
|
||||
reduce=True,
|
||||
)
|
||||
self.intermediate_size = self.down_proj.in_features
|
||||
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -394,27 +434,11 @@ class FlashLlamaLayer(nn.Module):
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
# faster input rms norm
|
||||
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.input_layernorm.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.input_layernorm.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True,
|
||||
)
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states,
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
@ -425,27 +449,13 @@ class FlashLlamaLayer(nn.Module):
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.post_attention_layernorm.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.post_attention_layernorm.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True,
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return mlp_output, residual
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
@ -492,7 +502,6 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
self.embed_tokens.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashLlamaLayer
|
||||
layer.self_attn.shuffle_qkv_dims()
|
||||
layer.self_attn.query_key_value.transpose_weight()
|
||||
layer.self_attn.o_proj.transpose_weight()
|
||||
layer.mlp.gate_up_proj.transpose_weight()
|
||||
@ -550,24 +559,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
# Faster final layer norm
|
||||
hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.norm.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.norm.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states, past_key_values
|
||||
|
||||
|
@ -32,10 +32,10 @@ class FlashLlama(FlashCausalLM):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
else:
|
||||
raise NotImplementedError("FlashCausalLM is only available on GPU")
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashCausalLM does not support quantization")
|
||||
raise NotImplementedError("FlashLlama does not support quantization")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left"
|
||||
@ -45,6 +45,7 @@ class FlashLlama(FlashCausalLM):
|
||||
model_id, revision=revision, tp_parallel=True
|
||||
)
|
||||
|
||||
# We do not use from_pretrained as we modified the model internal module layout
|
||||
try:
|
||||
filenames = weight_files(model_id, revision, ".bin")
|
||||
# Local files not found
|
||||
@ -71,73 +72,65 @@ class FlashLlama(FlashCausalLM):
|
||||
model,
|
||||
filenames: List[Path],
|
||||
):
|
||||
final_state_dict = {}
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
if "q_proj" in key:
|
||||
|
||||
# Fused qkv
|
||||
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
|
||||
final_key = layer_name + ".query_key_value.weight"
|
||||
if final_key not in final_state_dict:
|
||||
final_state_dict[final_key] = value.new_empty(
|
||||
(value.shape[0] * 3, value.shape[1])
|
||||
)
|
||||
final_state_dict[final_key][: value.shape[0]] = value
|
||||
elif "k_proj" in key:
|
||||
final_key = layer_name + ".query_key_value.weight"
|
||||
if final_key not in final_state_dict:
|
||||
final_state_dict[final_key] = value.new_empty(
|
||||
(value.shape[0] * 3, value.shape[1])
|
||||
)
|
||||
final_state_dict[final_key][
|
||||
value.shape[0] : value.shape[0] * 2
|
||||
] = value
|
||||
elif "v_proj" in key:
|
||||
final_key = layer_name + ".query_key_value.weight"
|
||||
if final_key not in final_state_dict:
|
||||
final_state_dict[final_key] = value.new_empty(
|
||||
(value.shape[0] * 3, value.shape[1])
|
||||
)
|
||||
final_state_dict[final_key][value.shape[0] * 2 :] = value
|
||||
elif "gate_proj" in key:
|
||||
|
||||
# Fused gate and up projs
|
||||
elif "gate_proj" in key or "up_proj" in key:
|
||||
final_key = layer_name + ".gate_up_proj.weight"
|
||||
if final_key not in final_state_dict:
|
||||
final_state_dict[final_key] = value.new_empty(
|
||||
(value.shape[0] * 2, value.shape[1])
|
||||
)
|
||||
final_state_dict[final_key][: value.shape[0]] = value
|
||||
elif "up_proj" in key:
|
||||
final_key = layer_name + ".gate_up_proj.weight"
|
||||
if final_key not in final_state_dict:
|
||||
final_state_dict[final_key] = value.new_empty(
|
||||
(value.shape[0] * 2, value.shape[1])
|
||||
)
|
||||
final_state_dict[final_key][value.shape[0] :] = value
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
del state_dict
|
||||
final_key = key
|
||||
|
||||
parameters = dict(model.named_parameters())
|
||||
for key, value in final_state_dict.items():
|
||||
current_parameter_tensor = parameters.get(key, None)
|
||||
module_name, param_name = key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
module_name, param_name = final_key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != value.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
value = value.contiguous()
|
||||
if current_parameter_tensor is not None:
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "query_key_value" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(value.shape[0] * 3, value.shape[1])
|
||||
)
|
||||
# Init gate and up proj
|
||||
elif "gate_up_proj" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(value.shape[0] * 2, value.shape[1])
|
||||
)
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = value
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
# Copy to correct slice
|
||||
if "q_proj" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "k_proj" in key:
|
||||
module._parameters[param_name][
|
||||
value.shape[0] : value.shape[0] * 2
|
||||
] = value
|
||||
elif "v_proj" in key:
|
||||
module._parameters[param_name][value.shape[0] * 2 :] = value
|
||||
elif "gate_proj" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "up_proj" in key:
|
||||
module._parameters[param_name][value.shape[0] :] = value
|
||||
else:
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
module._parameters[param_name] = value
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
|
||||
|
||||
@ -168,7 +161,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashGPTNeoXForCausalLM(config)
|
||||
model = FlashLlamaForCausalLM(config, process_group=self.process_group)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
@ -179,7 +172,6 @@ class FlashLlamaSharded(FlashLlama):
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
model.post_load_weights()
|
||||
self.model = model.eval().to(dtype)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
@ -196,19 +188,28 @@ class FlashLlamaSharded(FlashLlama):
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
current_parameter_tensor = parameters.get(name, None)
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
layer_name = ".".join(name.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
if "q_proj" in name or "k_proj" in name or "v_proj" in name:
|
||||
final_name = layer_name + ".query_key_value.weight"
|
||||
|
||||
# Fused gate and up projs
|
||||
elif "gate_proj" in name or "up_proj" in name:
|
||||
final_name = layer_name + ".gate_up_proj.weight"
|
||||
else:
|
||||
final_name = name
|
||||
|
||||
module_name, param_name = final_name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
@ -216,24 +217,18 @@ class FlashLlamaSharded(FlashLlama):
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
|
||||
elif name == "lm_head.weight" and model.model.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
@ -245,20 +240,53 @@ class FlashLlamaSharded(FlashLlama):
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous()
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = tensor
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "query_key_value" in final_name:
|
||||
module._parameters[param_name] = tensor.new_empty(
|
||||
(tensor.shape[0] * 3, tensor.shape[1])
|
||||
)
|
||||
# Init gate and up proj
|
||||
elif "gate_up_proj" in final_name:
|
||||
module._parameters[param_name] = tensor.new_empty(
|
||||
(tensor.shape[0] * 2, tensor.shape[1])
|
||||
)
|
||||
|
||||
# Init gate and up proj
|
||||
if "q_proj" in name:
|
||||
module._parameters[param_name][: tensor.shape[0]] = tensor
|
||||
elif "k_proj" in name:
|
||||
module._parameters[param_name][
|
||||
tensor.shape[0] : tensor.shape[0] * 2
|
||||
] = tensor
|
||||
elif "v_proj" in name:
|
||||
module._parameters[param_name][
|
||||
tensor.shape[0] * 2 :
|
||||
] = tensor
|
||||
elif "gate_proj" in name:
|
||||
module._parameters[param_name][: tensor.shape[0]] = tensor
|
||||
elif "up_proj" in name:
|
||||
module._parameters[param_name][tensor.shape[0] :] = tensor
|
||||
else:
|
||||
if current_parameter_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -268,7 +296,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||
max_s: int,
|
||||
past_key_values: Optional = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.model.gpt_neox.tp_embeddings:
|
||||
if self.model.model.tp_embeddings:
|
||||
logits, present = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
Loading…
Reference in New Issue
Block a user