This commit is contained in:
OlivierDehaene 2023-06-01 18:37:14 +02:00
parent c9e7471742
commit bfd6928c3e
3 changed files with 108 additions and 101 deletions

View File

@ -136,7 +136,7 @@ class FlashRWAttention(torch.nn.Module):
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
@ -153,12 +153,12 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[layer_past_present_indices] = kv
layer_past[past_present_indices] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@ -167,8 +167,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda_modif.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
start_seq,
end_seq,
@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
layer_past[past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@ -196,8 +196,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda_modif.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
start_seq_q,
end_seq_q,
@ -271,7 +271,7 @@ class FlashRWLargeAttention(torch.nn.Module):
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
past_present_indices,
cu_seqlens_q,
):
qkv = self.query_key_value(hidden_states)
@ -290,7 +290,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(kv[:, :, 0], cos, sin)
# Prefill
if layer_past_present_indices is None:
if past_present_indices is None:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
@ -323,7 +323,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
@ -430,7 +430,7 @@ class FlashRWLayer(nn.Module):
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
past_present_indices,
prefill,
):
if self.parallel_attn:
@ -446,7 +446,7 @@ class FlashRWLayer(nn.Module):
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
past_present_indices,
prefill,
)
@ -469,7 +469,7 @@ class FlashRWLayer(nn.Module):
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
past_present_indices,
prefill,
)
@ -517,7 +517,7 @@ class FlashRWLargeLayer(nn.Module):
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
past_present_indices,
cu_seqlens_q,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
@ -531,7 +531,7 @@ class FlashRWLargeLayer(nn.Module):
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
past_present_indices,
cu_seqlens_q,
)
@ -619,8 +619,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
# Create past tensor
past_key_values = hidden_states.new_zeros(
(
len(self.h),
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
@ -646,7 +646,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
start_seq_q,
end_seq_q,
max_s,
past_key_values[i],
past_key_values[:, i],
past_present_indices,
prefill,
)

View File

@ -7,6 +7,7 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda_modif
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
def forward(
self,
hidden_states,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.c_attn(hidden_states)
@ -166,9 +170,9 @@ class FlashMQAttention(torch.nn.Module):
key_value = key_value.view(-1, 2, 1, self.head_size)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = key_value
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda_modif.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = key_value
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda_modif.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
@ -277,21 +285,27 @@ class Block(nn.Module):
self,
hidden_states,
residual,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -350,43 +367,37 @@ class FlashSantacoderModel(nn.Module):
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
past_key_values = hidden_states.new_empty(
past_key_values = hidden_states.new_zeros(
(
pre_allocate_past_size,
len(self.h),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
1,
self.head_size,
self.head_size
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -406,9 +417,12 @@ class FlashSantacoderForCausalLM(nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -416,9 +430,12 @@ class FlashSantacoderForCausalLM(nn.Module):
hidden_states, present = self.transformer(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@ -186,8 +186,7 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
request_past_present_indices[:input_length] = 1
request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
past_present_indices.append(request_past_present_indices)
# Update
@ -210,10 +209,20 @@ class FlashCausalLMBatch(Batch):
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
past_present_indices = past_present_indices[0]
start_seq_prefill = start_seq
end_seq_prefill = end_seq
# Create tensors on device
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
all_input_ids_tensor = torch.tensor(
@ -222,19 +231,7 @@ class FlashCausalLMBatch(Batch):
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
if len(pb.requests) > 1:
past_present_indices = np.concatenate(past_present_indices)
start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
else:
past_present_indices = past_present_indices[0]
start_seq_prefill = start_seq
end_seq_prefill = end_seq
past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
if all_prefill_logprobs:
prefill_head_indices = None
@ -298,7 +295,7 @@ class FlashCausalLMBatch(Batch):
indices = []
# past indices to keep
past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)
# Create on CPU to only move to GPU once instead of at every copy
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
@ -352,7 +349,7 @@ class FlashCausalLMBatch(Batch):
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
past_key_values = self.past_key_values[:, past_indices]
past_key_values = self.past_key_values[past_indices]
# Move to GPU now that we have the whole tensor
start_seq = start_seq.to(device)
@ -409,11 +406,7 @@ class FlashCausalLMBatch(Batch):
)
end_seq_q = start_seq_q + 1
max_seqlen = 0
past_key_values = batches[0].past_key_values.new_empty((
batches[0].past_key_values.shape[0],
total_tokens,
*batches[0].past_key_values.shape[2:]
))
past_key_values = []
all_input_ids = []
@ -449,11 +442,6 @@ class FlashCausalLMBatch(Batch):
start_seq[start_index:end_index] = batch.start_seq + max_tokens
end_seq[start_index:end_index] = batch.end_seq + max_tokens
past_key_values[
:,
max_tokens: max_tokens + batch.max_tokens
] = batch.past_key_values
max_seqlen = max(max_seqlen, batch.max_seqlen)
all_input_ids.extend(batch.all_input_ids)
@ -464,6 +452,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
past_key_values.append(batch.past_key_values)
# Update
cumulative_batch_size += len(batch)
@ -480,6 +469,7 @@ class FlashCausalLMBatch(Batch):
),
)
past_key_values = torch.cat(past_key_values, dim=0)
past_present_indices = end_seq - 1
all_input_ids_tensor = torch.zeros(
@ -726,8 +716,8 @@ class FlashCausalLM(Model):
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.past_present_indices = torch.clone(batch.end_seq)
batch.end_seq += 1
batch.past_present_indices = batch.end_seq
batch.end_seq = batch.end_seq + 1
if prefill and prefill_logprobs:
# Get prefill logprobs