Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)

Commit 8c8d709994 ("40b working"), parent bbb1d9e704.
@@ -1,7 +1,6 @@
 import torch
 import torch.distributed
 
-from loguru import logger
 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
@@ -257,10 +256,6 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         self.num_groups = self.num_groups // process_group.size()
 
-        self.num_heads_config = num_heads
-        self.num_heads_kv_config = num_heads_kv
-        self.num_groups = 64
-
     def forward(
         self,
         hidden_states,
@@ -272,56 +267,32 @@ class FlashRWLargeAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        cu_shape = hidden_states.shape[0]
-
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(cu_shape, -1, self.num_heads_config // self.num_heads_kv_config +2, 64)
-        q = qkv[:, :, :-2]
-        k = qkv[:, :, [-2]]
-        v = qkv[:, :, [-1]]
-
-        k = torch.broadcast_to(k, q.shape)
-        v = torch.broadcast_to(v, q.shape)
-
-        q = q.reshape(cu_shape, -1, self.head_size)
-        k = k.reshape(cu_shape, -1, self.head_size)
-        v = v.reshape(cu_shape, -1, self.head_size)
-
-        logger.error(k.shape)
-
-        # qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
-        #
-        # # Split query from key_value
-        # query, kv = qkv.split(
-        #     [self.num_heads, 2],
-        #     dim=2,
-        # )
-        #
-        # # Prepare query and key_value for indexing
-        # query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
-        # kv = kv.transpose(1, 2)
+        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+        query, kv = qkv.split(
+            [self.num_heads, 2],
+            dim=2,
+        )
+        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
         # Inplace rotary
-        self.rotary_emb(q, cos, sin)
-        self.rotary_emb(k, cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(kv[:, :, 0], cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            # layer_past[...] = kv
-            # k, v = kv.split(1, dim=1)
+            layer_past[...] = kv
+            k, v = kv.split(1, dim=2)
             # Expand to query shape
-            # k = k.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            # v = v.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
-            layer_past[:, 0] = k
-            layer_past[:, 1] = v
-
             # output
-            attn_output = torch.empty_like(q)
+            attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
-                q,
+                query,
                 k,
                 v,
                 attn_output,
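For context on the new path above: the fused QKV projection is now treated as grouped multi-query attention, where each group carries num_heads query heads plus one shared key head and one shared value head that get broadcast to the query shape before flash attention. A minimal standalone sketch with made-up sizes (only the shape manipulation mirrors the diff; nothing here is the actual module):

import torch

# Illustrative sizes only (the real model reads these from its config).
num_tokens, num_groups, num_heads, head_size = 3, 2, 4, 8

# Fused projection output: per group, num_heads query heads plus one K and one V head.
qkv = torch.randn(num_tokens, num_groups * (num_heads + 2) * head_size)
qkv = qkv.view(-1, num_groups, num_heads + 2, head_size)

# Split query heads from the shared key/value pair.
query, kv = qkv.split([num_heads, 2], dim=2)
query = query.reshape(-1, num_groups * num_heads, head_size)  # [tokens, groups * heads, head_size]

# Prefill-style expansion: broadcast each group's single K/V head across its query heads.
k, v = kv.split(1, dim=2)  # each [tokens, groups, 1, head_size]
k = k.expand(-1, num_groups, num_heads, head_size).reshape(-1, num_groups * num_heads, head_size)
v = v.expand(-1, num_groups, num_heads, head_size).reshape(-1, num_groups * num_heads, head_size)

print(query.shape, k.shape, v.shape)  # all torch.Size([3, 8, 8])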
@@ -340,22 +311,19 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            # layer_past[layer_past_present_indices] = kv
-            # k, v = layer_past.split(1, dim=1)
+            layer_past[layer_past_present_indices] = kv
+            k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            # k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            # v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
-            layer_past[layer_past_present_indices, 0] = k
-            layer_past[layer_past_present_indices, 1] = v
-
             # output
-            attn_output = torch.empty_like(q)
+            attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
-                q,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                query,
+                k,
+                v,
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
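The decode branch above now scatters the fresh key/value pairs into the pre-allocated cache and reads k/v back from the whole cache before expanding them. A small sketch of that indexing, with hypothetical sizes and the per-token cache layout (num_groups, 2, head_size) that the later hunks introduce:

import torch

# Hypothetical sizes; layer_past uses the (num_groups, 2, head_size) per-token layout.
num_groups, head_size = 2, 8
total_slots, new_tokens = 6, 2

layer_past = torch.zeros(total_slots, num_groups, 2, head_size)  # pre-allocated cache
kv = torch.randn(new_tokens, num_groups, 2, head_size)           # fresh K/V for the new tokens
layer_past_present_indices = torch.tensor([4, 5])                # cache slots of the new tokens

# Scatter the new K/V into the cache, then read the full cache back as k and v.
layer_past[layer_past_present_indices] = kv
k, v = layer_past.split(1, dim=2)  # each [total_slots, num_groups, 1, head_size]
print(k.shape, v.shape)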
@@ -370,7 +338,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
 
-        return self.dense(attn_output.view(cu_shape, -1))
+        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
 
 
 class FlashMLP(nn.Module):
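attn_output keeps the [tokens, num_groups * num_heads, head_size] layout of query, so the return now flattens it explicitly instead of relying on the removed cu_shape variable. A toy example of that final reshape feeding a linear projection (sizes and the nn.Linear stand-in are assumptions, not the model's real dense layer):

import torch
from torch import nn

# Made-up sizes; assumes the output projection maps the flattened heads back to hidden size.
num_tokens, num_groups, num_heads, head_size = 3, 2, 4, 8
hidden_size = num_groups * num_heads * head_size

attn_output = torch.randn(num_tokens, num_groups * num_heads, head_size)
dense = nn.Linear(hidden_size, hidden_size)

# Flatten per-head outputs into one hidden vector per token before projecting.
out = dense(attn_output.view(-1, num_groups * num_heads * head_size))
print(out.shape)  # torch.Size([3, 64])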
@@ -591,7 +559,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.kv_size = self.h[0].self_attention.num_heads_kv
+            self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -606,7 +574,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.kv_size = self.h[0].self_attention.num_groups
+            self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
@@ -661,9 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    2,
-                    self.kv_size,
-                    self.head_size,
+                    *self.cache_size
                 )
             )
             layer_past_present_indices = None
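Since cache_size now encodes the per-token K/V layout for each architecture ((2, num_heads_kv, head_size) for "RefinedWebModel", (num_groups, 2, head_size) for "RefinedWeb"), the pre-allocated past can be built uniformly as (length, *cache_size). A sketch with invented head counts (the surrounding allocation call is not shown in this diff):

import torch

# Invented, illustrative sizes only.
head_size = 64
cache_size_mqa = (2, 1, head_size)      # "RefinedWebModel": (2, num_heads_kv, head_size)
cache_size_grouped = (8, 2, head_size)  # "RefinedWeb": (num_groups, 2, head_size)

pre_allocate_past_size = 16
past_mqa = torch.empty((pre_allocate_past_size, *cache_size_mqa))
past_grouped = torch.empty((pre_allocate_past_size, *cache_size_grouped))
print(past_mqa.shape, past_grouped.shape)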