mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00

commit bbb1d9e704
parent 73cf93f1ee

    working
@@ -257,6 +257,10 @@ class FlashRWLargeAttention(torch.nn.Module):
 
             self.num_groups = self.num_groups // process_group.size()
 
+        self.num_heads_config = num_heads
+        self.num_heads_kv_config = num_heads_kv
+        self.num_groups = 64
+
     def forward(
         self,
         hidden_states,
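The new attributes stash the raw head configuration so the forward pass can recompute the grouped QKV layout, while the hard-coded `self.num_groups = 64` overrides the value computed above. As a rough illustration of the bookkeeping involved, here is a standalone sketch with assumed Falcon-40B-style sizes; none of the numbers below are taken from this commit.

```python
# Head/group bookkeeping for grouped (multi-query) attention, with assumed sizes.
num_heads = 128        # query heads (assumption)
num_heads_kv = 8       # KV groups (assumption)
head_size = 64
world_size = 4         # tensor-parallel ranks (assumption)

heads_per_group = num_heads // num_heads_kv                    # query heads sharing one K/V pair
qkv_cols = num_heads_kv * (heads_per_group + 2) * head_size    # fused QKV width per token

# Under tensor parallelism the KV groups are sharded across ranks, which is what
# `self.num_groups = self.num_groups // process_group.size()` does in the diff.
num_groups_per_rank = num_heads_kv // world_size

print(heads_per_group, qkv_cols, num_groups_per_rank)  # 16 9216 2
```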
@@ -268,37 +272,56 @@ class FlashRWLargeAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
+        cu_shape = hidden_states.shape[0]
+
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+        qkv = qkv.view(cu_shape, -1, self.num_heads_config // self.num_heads_kv_config +2, 64)
+        q = qkv[:, :, :-2]
+        k = qkv[:, :, [-2]]
+        v = qkv[:, :, [-1]]
 
-        # Split query from key_value
-        query, kv = qkv.split(
-            [self.num_heads, 2],
-            dim=2,
-        )
+        k = torch.broadcast_to(k, q.shape)
+        v = torch.broadcast_to(v, q.shape)
 
-        # Prepare query and key_value for indexing
-        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
-        kv = kv.transpose(1, 2)
+        q = q.reshape(cu_shape, -1, self.head_size)
+        k = k.reshape(cu_shape, -1, self.head_size)
+        v = v.reshape(cu_shape, -1, self.head_size)
+
+        logger.error(k.shape)
+
+        # qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+        #
+        # # Split query from key_value
+        # query, kv = qkv.split(
+        #     [self.num_heads, 2],
+        #     dim=2,
+        # )
+        #
+        # # Prepare query and key_value for indexing
+        # query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
+        # kv = kv.transpose(1, 2)
 
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(q, cos, sin)
+        self.rotary_emb(k, cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = kv
-            k, v = kv.split(1, dim=1)
+            # layer_past[...] = kv
+            # k, v = kv.split(1, dim=1)
             # Expand to query shape
-            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # k = k.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # v = v.transpose(1, 2).expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            layer_past[:, 0] = k
+            layer_past[:, 1] = v
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
+                q,
                 k,
                 v,
                 attn_output,
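The rewritten projection path views the fused QKV output as (tokens, groups, query heads per group + 2, head size), slices out q/k/v, broadcasts each group's single key/value head across its query heads, and flattens the head axes for flash attention. Below is a minimal, self-contained sketch of that reshaping with made-up sizes; variable names mirror the diff but nothing is taken verbatim from the repository.

```python
import torch

# Assumed sizes: 4 KV groups, 16 query heads per group, head size 64, 5 tokens.
num_tokens, num_groups, heads_per_group, head_size = 5, 4, 16, 64

# Fused QKV projection output for a flattened batch of tokens.
qkv = torch.randn(num_tokens, num_groups * (heads_per_group + 2) * head_size)

# (tokens, groups, query heads + 2, head_size): the last two slots of each group are K and V.
qkv = qkv.view(num_tokens, num_groups, heads_per_group + 2, head_size)
q = qkv[:, :, :-2]      # (tokens, groups, heads_per_group, head_size)
k = qkv[:, :, [-2]]     # (tokens, groups, 1, head_size)
v = qkv[:, :, [-1]]

# Broadcast the single K/V head of each group across that group's query heads.
k = torch.broadcast_to(k, q.shape)
v = torch.broadcast_to(v, q.shape)

# Flatten groups x heads into one head axis, as the flash attention kernel expects.
q = q.reshape(num_tokens, -1, head_size)
k = k.reshape(num_tokens, -1, head_size)
v = v.reshape(num_tokens, -1, head_size)

assert q.shape == k.shape == v.shape == (num_tokens, num_groups * heads_per_group, head_size)
```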
@@ -317,19 +340,22 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
-            k, v = layer_past.split(1, dim=1)
+            # layer_past[layer_past_present_indices] = kv
+            # k, v = layer_past.split(1, dim=1)
             # Expand to query shape
-            k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            # v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+            layer_past[layer_past_present_indices, 0] = k
+            layer_past[layer_past_present_indices, 1] = v
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                k,
-                v,
+                q,
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
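The prefill branch now fills the whole cache with `layer_past[:, 0/1] = k/v`, while the decode branch scatters only the new tokens at `layer_past_present_indices` and hands the full cache slices to flash attention. Here is a toy sketch of that pattern, with simplified shapes that are assumptions rather than the repository's exact cache layout.

```python
import torch

# Assumed cache layout: (total_tokens, 2, heads, head_size), slot 0 = K, slot 1 = V.
total_tokens, num_heads, head_size = 8, 4, 16
layer_past = torch.zeros(total_tokens, 2, num_heads, head_size)

# Prefill: keys/values for all prompt tokens are written at once.
k = torch.randn(total_tokens, num_heads, head_size)
v = torch.randn(total_tokens, num_heads, head_size)
layer_past[:, 0] = k
layer_past[:, 1] = v

# Decode: only the newly generated token is written, at its cache index.
layer_past_present_indices = torch.tensor([3])   # assumed position of the new token
new_k = torch.randn(1, num_heads, head_size)
new_v = torch.randn(1, num_heads, head_size)
layer_past[layer_past_present_indices, 0] = new_k
layer_past[layer_past_present_indices, 1] = new_v

# Attention for the decode step then reads the full cached keys/values.
cached_k = layer_past[:, 0]   # (total_tokens, num_heads, head_size)
cached_v = layer_past[:, 1]
```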
@@ -344,7 +370,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
 
-        return self.dense(attn_output.view(-1, self.num_heads * self.num_groups * self.head_size))
+        return self.dense(attn_output.view(cu_shape, -1))
 
 
 class FlashMLP(nn.Module):
@@ -498,8 +524,8 @@ class FlashRWLargeLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        ln_attn, residual = self.ln_attn(hidden_states, residual)
-        ln_mlp, _ = self.ln_mlp(hidden_states, residual)
+        ln_attn, _ = self.ln_attn(hidden_states)
+        ln_mlp, _ = self.ln_mlp(hidden_states)
 
         # Self attention.
         attn_output = self.self_attention(
@@ -522,7 +548,7 @@ class FlashRWLargeLayer(nn.Module):
         if self.process_group is not None:
             torch.distributed.all_reduce(intermediate, group=self.process_group)
 
-        return intermediate, residual
+        return intermediate + hidden_states, None
 
 
 class FlashRWPreTrainedModel(PreTrainedModel):
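The `FlashRWLargeLayer` changes drop the chained residual bookkeeping in favour of the parallel attention + MLP structure: both branches read a layer norm of the same block input, their outputs are summed, and a single `intermediate + hidden_states` closes the block. A toy module illustrating that structure with plain PyTorch stand-ins for the flash kernels:

```python
import torch
from torch import nn

class ParallelBlock(nn.Module):
    """Parallel attention + MLP block: one residual add per block."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.ln_attn = nn.LayerNorm(hidden_size)
        self.ln_mlp = nn.LayerNorm(hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)   # stand-in for self-attention
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        ln_attn = self.ln_attn(hidden_states)
        ln_mlp = self.ln_mlp(hidden_states)

        attn_output = self.attn(ln_attn)
        mlp_output = self.mlp(ln_mlp)

        # Both branches are summed, then added to the block input exactly once,
        # mirroring `return intermediate + hidden_states, None` in the diff.
        intermediate = mlp_output + attn_output
        return intermediate + hidden_states


x = torch.randn(3, 7, 32)          # (batch, sequence, hidden)
print(ParallelBlock(32)(x).shape)  # torch.Size([3, 7, 32])
```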