OlivierDehaene 2023-06-01 18:37:14 +02:00
parent c9e7471742
commit bfd6928c3e
3 changed files with 108 additions and 101 deletions

File 1 of 3

@@ -136,7 +136,7 @@ class FlashRWAttention(torch.nn.Module):
         end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         prefill,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -153,12 +153,12 @@ class FlashRWAttention(torch.nn.Module):

         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@@ -167,8 +167,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 start_seq,
                 end_seq,
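Note: kv[:, 0] and torch.select(kv, dim=1, index=0) are equivalent here; both return a view into kv, so the in-place rotary embedding and the attention kernel still read and write the same storage. A minimal sketch (the shapes below are illustrative, not taken from this model):

    import torch

    # Illustrative packed KV tensor: (total_tokens, 2, num_heads, head_size)
    kv = torch.randn(16, 2, 8, 64)

    k_index = kv[:, 0]                           # basic indexing -> view
    k_select = torch.select(kv, dim=1, index=0)  # explicit select -> same view

    assert torch.equal(k_index, k_select)
    assert k_select.data_ptr() == kv.data_ptr()  # no copy: shares kv's storage

    # Because it is a view, in-place ops (e.g. the rotary embedding) mutate kv
    k_select.add_(1.0)
    assert torch.equal(kv[:, 0], k_select)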
@@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@@ -196,8 +196,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -271,7 +271,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         cu_seqlens,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         cu_seqlens_q,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -290,7 +290,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(kv[:, :, 0], cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -323,7 +323,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 layer_past.unsqueeze(2)
@@ -430,7 +430,7 @@ class FlashRWLayer(nn.Module):
         end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         prefill,
     ):
         if self.parallel_attn:
@@ -446,7 +446,7 @@ class FlashRWLayer(nn.Module):
                 end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
             )
@@ -469,7 +469,7 @@ class FlashRWLayer(nn.Module):
                 end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
             )
@@ -517,7 +517,7 @@ class FlashRWLargeLayer(nn.Module):
         cu_seqlens,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         cu_seqlens_q,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
@@ -531,7 +531,7 @@ class FlashRWLargeLayer(nn.Module):
             cu_seqlens,
             max_s,
             layer_past,
-            layer_past_present_indices,
+            past_present_indices,
             cu_seqlens_q,
         )
@@ -619,8 +619,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
            # Create past tensor
            past_key_values = hidden_states.new_zeros(
                (
-                   len(self.h),
                    pre_allocate_past_size,
+                   len(self.h),
                    *self.cache_size,
                )
            )
@@ -646,7 +646,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                start_seq_q,
                end_seq_q,
                max_s,
-               past_key_values[i],
+               past_key_values[:, i],
                past_present_indices,
                prefill,
            )
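Note: the two hunks above flip the pre-allocated cache from layer-major (len(self.h), pre_allocate_past_size, ...) to token-major (pre_allocate_past_size, len(self.h), ...), and each layer now receives its slice as past_key_values[:, i]. A rough sketch of that indexing, with invented sizes:

    import torch

    num_layers, total_tokens = 4, 32
    cache_size = (2, 8, 64)  # illustrative: (2 for key/value, num_heads, head_size)

    # New layout: tokens on dim 0, layers on dim 1
    past_key_values = torch.zeros(total_tokens, num_layers, *cache_size)

    # Per-layer view handed to layer i (what past_key_values[:, i] selects)
    layer_past = past_key_values[:, 2]
    assert layer_past.shape == (total_tokens, *cache_size)

    # Writing the current tokens' K/V at their flat cache positions
    past_present_indices = torch.tensor([5, 6, 7])  # illustrative positions
    kv = torch.randn(len(past_present_indices), *cache_size)
    layer_past[past_present_indices] = kv           # writes into past_key_values
    assert torch.equal(past_key_values[5:8, 2], kv)

Keeping tokens on dim 0 is what later lets the batch code filter the cache with a boolean mask and merge batches with torch.cat along dim 0 (see the third file below).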

File 2 of 3

@@ -7,6 +7,7 @@ from typing import Optional

 # Flash attention imports
 import flash_attn_cuda_modif
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.c_attn(hidden_states)
@@ -166,9 +170,9 @@ class FlashMQAttention(torch.nn.Module):
         key_value = key_value.view(-1, 2, 1, self.head_size)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -277,21 +285,27 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
         hidden_states = self.attn(
             hidden_states,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -350,43 +367,37 @@ class FlashSantacoderModel(nn.Module):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+            prefill = True
             # Create past tensor
-            past_key_values = hidden_states.new_empty(
+            past_key_values = hidden_states.new_zeros(
                 (
+                    pre_allocate_past_size,
                     len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     1,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )

         hidden_states, _ = self.ln_f(hidden_states, residual)
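Note: the rewritten forward above drops the old padding-and-slice bookkeeping (slice_past_index, layer_past_present_indices = None) in favor of a single zero-initialized cache of pre_allocate_past_size rows and a plain prefill flag; both prefill and decode then write through the same layer_past[past_present_indices] = kv scatter shown earlier. A condensed sketch of just the cache-allocation branch (the helper name ensure_cache and all sizes are invented for illustration; the real forward runs the attention layers between these steps):

    import torch

    def ensure_cache(hidden_states, past_key_values, pre_allocate_past_size,
                     num_layers, head_size):
        # Prefill: build the whole token-major cache up front, zero-filled and
        # large enough for the prompt plus every future decode step.
        if past_key_values is None:
            assert pre_allocate_past_size is not None
            prefill = True
            past_key_values = hidden_states.new_zeros(
                (pre_allocate_past_size, num_layers, 2, 1, head_size)
            )
        # Decode: reuse the cache created at prefill time.
        else:
            prefill = False
        return past_key_values, prefill

    hidden_states = torch.randn(5, 32)  # illustrative activations
    cache, prefill = ensure_cache(hidden_states, None, 16, num_layers=2, head_size=8)
    assert prefill and cache.shape == (16, 2, 2, 1, 8)
    cache, prefill = ensure_cache(hidden_states, cache, None, num_layers=2, head_size=8)
    assert not prefill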
@@ -406,9 +417,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -416,9 +430,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )

File 3 of 3

@@ -186,8 +186,7 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

-            request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
-            request_past_present_indices[:input_length] = 1
+            request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
             past_present_indices.append(request_past_present_indices)

             # Update
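Note: the per-request indices are now absolute positions into the flat, token-major cache rather than a boolean mask per request. A small sketch of how they line up across requests, assuming cumulative_max_length advances by input_length + max_new_tokens - 1 per request, as the replaced np.zeros sizing suggests:

    import torch

    requests = [
        # (input_length, max_new_tokens) -- illustrative values
        (3, 4),
        (2, 5),
    ]

    past_present_indices = []
    cumulative_max_length = 0
    for input_length, max_new_tokens in requests:
        # Absolute positions of this request's prompt tokens in the flat cache
        indices = torch.arange(
            cumulative_max_length,
            cumulative_max_length + input_length,
            dtype=torch.int64,
        )
        past_present_indices.append(indices)
        # Each request reserves room for its prompt plus its future decode steps
        cumulative_max_length += input_length + max_new_tokens - 1

    print([t.tolist() for t in past_present_indices])
    # [[0, 1, 2], [6, 7]] -> the second request starts at offset 3 + 4 - 1 = 6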
@@ -210,10 +209,20 @@ class FlashCausalLMBatch(Batch):
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
+            past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
+
+            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
+            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
+            past_present_indices = past_present_indices[0]
+
+            start_seq_prefill = start_seq
+            end_seq_prefill = end_seq

         # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
@@ -222,19 +231,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
         end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
-
-        if len(pb.requests) > 1:
-            past_present_indices = np.concatenate(past_present_indices)
-
-            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
-            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
-        else:
-            past_present_indices = past_present_indices[0]
-
-            start_seq_prefill = start_seq
-            end_seq_prefill = end_seq
-
-        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
+        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)

         if all_prefill_logprobs:
             prefill_head_indices = None
@@ -298,7 +295,7 @@ class FlashCausalLMBatch(Batch):
         indices = []

         # past indices to keep
-        past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
+        past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)

         # Create on CPU to only move to GPU once instead of at every copy
         start_seq = torch.empty(len(request_ids), dtype=torch.int32)
@@ -352,7 +349,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
-        past_key_values = self.past_key_values[:, past_indices]
+        past_key_values = self.past_key_values[past_indices]

         # Move to GPU now that we have the whole tensor
         start_seq = start_seq.to(device)
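Note: with tokens on dim 0, filtering a batch keeps whole cache rows with a boolean mask over the first dimension (the old code masked dim 1). A minimal sketch with invented sizes:

    import torch

    total_tokens, num_layers = 10, 2
    # Flat, token-major cache: one row per cached token position
    past_key_values = torch.randn(total_tokens, num_layers, 2, 1, 4)

    # Keep only the rows belonging to the surviving requests
    past_indices = torch.zeros(past_key_values.shape[0], dtype=torch.bool)
    past_indices[2:5] = True  # e.g. the slot reserved for one remaining request

    filtered = past_key_values[past_indices]
    assert filtered.shape == (3, num_layers, 2, 1, 4)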
@@ -409,11 +406,7 @@ class FlashCausalLMBatch(Batch):
         )
         end_seq_q = start_seq_q + 1
         max_seqlen = 0
-        past_key_values = batches[0].past_key_values.new_empty((
-            batches[0].past_key_values.shape[0],
-            total_tokens,
-            *batches[0].past_key_values.shape[2:]
-        ))
+        past_key_values = []

         all_input_ids = []
@@ -449,11 +442,6 @@ class FlashCausalLMBatch(Batch):
             start_seq[start_index:end_index] = batch.start_seq + max_tokens
             end_seq[start_index:end_index] = batch.end_seq + max_tokens

-            past_key_values[
-                :,
-                max_tokens: max_tokens + batch.max_tokens
-            ] = batch.past_key_values
-
             max_seqlen = max(max_seqlen, batch.max_seqlen)

             all_input_ids.extend(batch.all_input_ids)
@@ -464,6 +452,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
+            past_key_values.append(batch.past_key_values)

             # Update
             cumulative_batch_size += len(batch)
@@ -480,6 +469,7 @@ class FlashCausalLMBatch(Batch):
             ),
         )

+        past_key_values = torch.cat(past_key_values, dim=0)
         past_present_indices = end_seq - 1

         all_input_ids_tensor = torch.zeros(
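Note: merging batches now concatenates the per-batch caches along the token dimension instead of copying them into a pre-sized buffer; the start_seq/end_seq offsets shifted by max_tokens earlier in the loop keep each request pointing at its own rows. A short sketch with invented shapes:

    import torch

    # Two batches' token-major caches: (tokens, num_layers, 2, kv_heads, head_size)
    cache_a = torch.randn(6, 2, 2, 1, 8)
    cache_b = torch.randn(9, 2, 2, 1, 8)

    past_key_values = torch.cat([cache_a, cache_b], dim=0)
    assert past_key_values.shape[0] == cache_a.shape[0] + cache_b.shape[0]

    # Batch B's rows now start after batch A's, so its start_seq/end_seq
    # (and hence its past_present_indices) must be shifted by cache_a.shape[0]
    shift = cache_a.shape[0]
    assert torch.equal(past_key_values[shift], cache_b[0])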
@@ -726,8 +716,8 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
-        batch.past_present_indices = torch.clone(batch.end_seq)
-        batch.end_seq += 1
+        batch.past_present_indices = batch.end_seq
+        batch.end_seq = batch.end_seq + 1

         if prefill and prefill_logprobs:
             # Get prefill logprobs