feat(server): add flash attention llama

This commit is contained in:
OlivierDehaene 2023-03-28 16:12:05 +02:00
parent 71402ed4c7
commit cd5d0a96ba
3 changed files with 190 additions and 166 deletions

View File

@ -19,7 +19,7 @@ from text_generation_server.models.t5 import T5Sharded
try:
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
from text_generation_server.models.flash_santacoder import FlashSantacoder
from text_generation_server.models.flash_llama import FlashLlama
from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
FLASH_ATTENTION = (
torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
@ -95,7 +95,11 @@ def get_model(
if model_type == "llama":
if sharded:
raise NotImplementedError
if FLASH_ATTENTION:
return FlashLlamaSharded(model_id, revision, quantize=quantize)
raise NotImplementedError(
"sharded is not supported for llama when FLASH_ATTENTION=0"
)
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(model_id, revision, quantize=quantize)

View File

@ -1,3 +1,23 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
@ -23,15 +43,45 @@ class LlamaRMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 6144:
if residual is not None:
hidden_states += residual
residual = hidden_states
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
return self.weight * hidden_states
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
else:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
class FastLinear(nn.Linear):
@ -183,8 +233,6 @@ class PositionRotaryEmbedding(RotaryEmbedding):
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
@ -245,16 +293,6 @@ class FlashLlamaAttention(torch.nn.Module):
process_group=process_group,
)
def shuffle_qkv_dims(self):
"""Swap dims to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
)
.permute(1, 0, 2, 3)
.reshape(-1, self.hidden_size)
)
def forward(
self,
hidden_states,
@ -333,7 +371,6 @@ class LlamaMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
self.intermediate_size = intermediate_size
if process_group is None:
# Fuse gate and up proj
@ -341,6 +378,7 @@ class LlamaMLP(nn.Module):
hidden_size, 2 * intermediate_size, bias=False
)
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
self.intermediate_size = intermediate_size
else:
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear(
@ -356,6 +394,8 @@ class LlamaMLP(nn.Module):
process_group=process_group,
reduce=True,
)
self.intermediate_size = self.down_proj.in_features
self.process_group = process_group
def forward(self, hidden_states):
@ -394,27 +434,11 @@ class FlashLlamaLayer(nn.Module):
layer_past_present_indices,
cu_seqlens_q,
):
# faster input rms norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.input_layernorm.weight,
None,
None,
None,
None,
None,
0.0,
self.input_layernorm.variance_epsilon,
1.0,
0,
None,
False,
True,
)
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
hidden_states,
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlens,
@ -425,27 +449,13 @@ class FlashLlamaLayer(nn.Module):
)
# faster post attention rms norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.post_attention_layernorm.weight,
None,
None,
None,
None,
None,
0.0,
self.post_attention_layernorm.variance_epsilon,
1.0,
0,
None,
False,
True,
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(hidden_states)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, residual
return mlp_output, attn_res
class FlashLlamaModel(torch.nn.Module):
@ -492,7 +502,6 @@ class FlashLlamaModel(torch.nn.Module):
self.embed_tokens.add_null_idx()
for layer in self.layers:
layer: FlashLlamaLayer
layer.self_attn.shuffle_qkv_dims()
layer.self_attn.query_key_value.transpose_weight()
layer.self_attn.o_proj.transpose_weight()
layer.mlp.gate_up_proj.transpose_weight()
@ -550,24 +559,7 @@ class FlashLlamaModel(torch.nn.Module):
cu_seqlens_q,
)
# Faster final layer norm
hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.norm.weight,
None,
None,
None,
None,
None,
0.0,
self.norm.variance_epsilon,
1.0,
0,
None,
False,
True,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values

View File

@ -32,10 +32,10 @@ class FlashLlama(FlashCausalLM):
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashCausalLM is only available on GPU")
raise NotImplementedError("FlashLlama is only available on GPU")
if quantize:
raise NotImplementedError("FlashCausalLM does not support quantization")
raise NotImplementedError("FlashLlama does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left"
@ -45,6 +45,7 @@ class FlashLlama(FlashCausalLM):
model_id, revision=revision, tp_parallel=True
)
# We do not use from_pretrained as we modified the model internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
@ -71,73 +72,65 @@ class FlashLlama(FlashCausalLM):
model,
filenames: List[Path],
):
final_state_dict = {}
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
layer_name = ".".join(key.split(".")[:4])
if "q_proj" in key:
# Fused qkv
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
final_key = layer_name + ".query_key_value.weight"
if final_key not in final_state_dict:
final_state_dict[final_key] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
final_state_dict[final_key][: value.shape[0]] = value
elif "k_proj" in key:
final_key = layer_name + ".query_key_value.weight"
if final_key not in final_state_dict:
final_state_dict[final_key] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
final_state_dict[final_key][
value.shape[0] : value.shape[0] * 2
] = value
elif "v_proj" in key:
final_key = layer_name + ".query_key_value.weight"
if final_key not in final_state_dict:
final_state_dict[final_key] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
final_state_dict[final_key][value.shape[0] * 2 :] = value
elif "gate_proj" in key:
# Fused gate and up projs
elif "gate_proj" in key or "up_proj" in key:
final_key = layer_name + ".gate_up_proj.weight"
if final_key not in final_state_dict:
final_state_dict[final_key] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
final_state_dict[final_key][: value.shape[0]] = value
elif "up_proj" in key:
final_key = layer_name + ".gate_up_proj.weight"
if final_key not in final_state_dict:
final_state_dict[final_key] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
final_state_dict[final_key][value.shape[0] :] = value
else:
final_state_dict[key] = value
del state_dict
final_key = key
parameters = dict(model.named_parameters())
for key, value in final_state_dict.items():
current_parameter_tensor = parameters.get(key, None)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != value.shape
):
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
value = value.contiguous()
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
if current_parameter_tensor is not None:
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
# Copy to correct slice
if "q_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "k_proj" in key:
module._parameters[param_name][
value.shape[0] : value.shape[0] * 2
] = value
elif "v_proj" in key:
module._parameters[param_name][value.shape[0] * 2 :] = value
elif "gate_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "up_proj" in key:
module._parameters[param_name][value.shape[0] :] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
torch.cuda.empty_cache()
model.post_load_weights()
@ -168,7 +161,7 @@ class FlashLlamaSharded(FlashLlama):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashGPTNeoXForCausalLM(config)
model = FlashLlamaForCausalLM(config, process_group=self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
@ -179,7 +172,6 @@ class FlashLlamaSharded(FlashLlama):
rank=self.rank,
world_size=self.world_size,
)
model.post_load_weights()
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
@ -196,19 +188,28 @@ class FlashLlamaSharded(FlashLlama):
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
layer_name = ".".join(name.split(".")[:4])
# Fused qkv
if "q_proj" in name or "k_proj" in name or "v_proj" in name:
final_name = layer_name + ".query_key_value.weight"
# Fused gate and up projs
elif "gate_proj" in name or "up_proj" in name:
final_name = layer_name + ".gate_up_proj.weight"
else:
final_name = name
module_name, param_name = final_name.rsplit(".", 1)
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
@ -216,24 +217,18 @@ class FlashLlamaSharded(FlashLlama):
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
elif name == "lm_head.weight" and model.model.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
@ -245,20 +240,53 @@ class FlashLlamaSharded(FlashLlama):
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 3, tensor.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 2, tensor.shape[1])
)
# Init gate and up proj
if "q_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "k_proj" in name:
module._parameters[param_name][
tensor.shape[0] : tensor.shape[0] * 2
] = tensor
elif "v_proj" in name:
module._parameters[param_name][
tensor.shape[0] * 2 :
] = tensor
elif "gate_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "up_proj" in name:
module._parameters[param_name][tensor.shape[0] :] = tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights()
def forward(
self,
@ -268,7 +296,7 @@ class FlashLlamaSharded(FlashLlama):
max_s: int,
past_key_values: Optional = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.model.gpt_neox.tp_embeddings:
if self.model.model.tp_embeddings:
logits, present = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,