mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
working
This commit is contained in:
parent
c9e7471742
commit
bfd6928c3e
@ -136,7 +136,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
@ -153,12 +153,12 @@ class FlashRWAttention(torch.nn.Module):
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(kv[:, 0], cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
# Prefill
|
||||
if prefill:
|
||||
# Copy to layer past
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
layer_past[past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
@ -167,8 +167,8 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
start_seq,
|
||||
end_seq,
|
||||
@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# Decode
|
||||
else:
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
layer_past[past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
@ -196,8 +196,8 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
@ -271,7 +271,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
@ -290,7 +290,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
self.rotary_emb(kv[:, :, 0], cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
if past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = kv
|
||||
# Expand to query shape
|
||||
@ -323,7 +323,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
# Decode
|
||||
else:
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
layer_past[past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
kv = (
|
||||
layer_past.unsqueeze(2)
|
||||
@ -430,7 +430,7 @@ class FlashRWLayer(nn.Module):
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
):
|
||||
if self.parallel_attn:
|
||||
@ -446,7 +446,7 @@ class FlashRWLayer(nn.Module):
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
)
|
||||
|
||||
@ -469,7 +469,7 @@ class FlashRWLayer(nn.Module):
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
)
|
||||
|
||||
@ -517,7 +517,7 @@ class FlashRWLargeLayer(nn.Module):
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||
@ -531,7 +531,7 @@ class FlashRWLargeLayer(nn.Module):
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
@ -619,8 +619,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
# Create past tensor
|
||||
past_key_values = hidden_states.new_zeros(
|
||||
(
|
||||
len(self.h),
|
||||
pre_allocate_past_size,
|
||||
len(self.h),
|
||||
*self.cache_size,
|
||||
)
|
||||
)
|
||||
@ -646,7 +646,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_key_values[i],
|
||||
past_key_values[:, i],
|
||||
past_present_indices,
|
||||
prefill,
|
||||
)
|
||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda_modif
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
):
|
||||
qkv = self.c_attn(hidden_states)
|
||||
|
||||
@ -166,9 +170,9 @@ class FlashMQAttention(torch.nn.Module):
|
||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
if prefill:
|
||||
# Copy to layer past
|
||||
layer_past[...] = key_value
|
||||
layer_past[past_present_indices] = key_value
|
||||
# Expand from 1 to num_heads
|
||||
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
key_value[:, 0],
|
||||
key_value[:, 1],
|
||||
torch.select(key_value, dim=1, index=0),
|
||||
torch.select(key_value, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq,
|
||||
end_seq,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# Decode
|
||||
else:
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = key_value
|
||||
layer_past[past_present_indices] = key_value
|
||||
# Expand from 1 to num_heads
|
||||
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
key_value[:, 0],
|
||||
key_value[:, 1],
|
||||
torch.select(key_value, dim=1, index=0),
|
||||
torch.select(key_value, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
1,
|
||||
max_s,
|
||||
0.0,
|
||||
@ -277,21 +285,27 @@ class Block(nn.Module):
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
):
|
||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||
|
||||
hidden_states = self.attn(
|
||||
hidden_states,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||
@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
past_present_indices,
|
||||
past_key_values=None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
@ -350,43 +367,37 @@ class FlashSantacoderModel(nn.Module):
|
||||
|
||||
# Prefill
|
||||
if past_key_values is None:
|
||||
assert pre_allocate_past_size is not None
|
||||
|
||||
prefill = True
|
||||
|
||||
# Create past tensor
|
||||
past_key_values = hidden_states.new_empty(
|
||||
past_key_values = hidden_states.new_zeros(
|
||||
(
|
||||
pre_allocate_past_size,
|
||||
len(self.h),
|
||||
len(hidden_states)
|
||||
if pre_allocate_past_size is None
|
||||
else pre_allocate_past_size,
|
||||
2,
|
||||
1,
|
||||
self.head_size,
|
||||
self.head_size
|
||||
)
|
||||
)
|
||||
layer_past_present_indices = None
|
||||
slice_past_index = len(hidden_states)
|
||||
# Decode
|
||||
else:
|
||||
# Create indices from cumulative sequence lengths
|
||||
layer_past_present_indices = cu_seqlens[1:] - 1
|
||||
slice_past_index = None
|
||||
prefill = False
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.h):
|
||||
# We added padding that we now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
else past_key_values[i, :slice_past_index]
|
||||
)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past_key_values,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
torch.select(past_key_values, dim=1, index=i),
|
||||
past_present_indices,
|
||||
prefill,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
@ -404,21 +415,27 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
input_ids,
|
||||
position_ids,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_present_indices,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states, present = self.transformer(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_present_indices,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
|
@ -186,8 +186,7 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
|
||||
request_past_present_indices[:input_length] = 1
|
||||
request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
|
||||
past_present_indices.append(request_past_present_indices)
|
||||
|
||||
# Update
|
||||
@ -210,10 +209,20 @@ class FlashCausalLMBatch(Batch):
|
||||
if len(pb.requests) > 1:
|
||||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
position_ids = torch.cat(position_ids)
|
||||
|
||||
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
|
||||
|
||||
start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
|
||||
end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
position_ids = position_ids[0]
|
||||
|
||||
past_present_indices = past_present_indices[0]
|
||||
|
||||
start_seq_prefill = start_seq
|
||||
end_seq_prefill = end_seq
|
||||
|
||||
# Create tensors on device
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
@ -222,19 +231,7 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
|
||||
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
|
||||
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
|
||||
|
||||
if len(pb.requests) > 1:
|
||||
past_present_indices = np.concatenate(past_present_indices)
|
||||
|
||||
start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
|
||||
end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
|
||||
else:
|
||||
past_present_indices = past_present_indices[0]
|
||||
|
||||
start_seq_prefill = start_seq
|
||||
end_seq_prefill = end_seq
|
||||
|
||||
past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
|
||||
past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
@ -298,7 +295,7 @@ class FlashCausalLMBatch(Batch):
|
||||
indices = []
|
||||
|
||||
# past indices to keep
|
||||
past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
|
||||
past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)
|
||||
|
||||
# Create on CPU to only move to GPU once instead of at every copy
|
||||
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
|
||||
@ -352,7 +349,7 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids = self.position_ids[indices]
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
past_key_values = self.past_key_values[:, past_indices]
|
||||
past_key_values = self.past_key_values[past_indices]
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
start_seq = start_seq.to(device)
|
||||
@ -409,11 +406,7 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
end_seq_q = start_seq_q + 1
|
||||
max_seqlen = 0
|
||||
past_key_values = batches[0].past_key_values.new_empty((
|
||||
batches[0].past_key_values.shape[0],
|
||||
total_tokens,
|
||||
*batches[0].past_key_values.shape[2:]
|
||||
))
|
||||
past_key_values = []
|
||||
|
||||
all_input_ids = []
|
||||
|
||||
@ -449,11 +442,6 @@ class FlashCausalLMBatch(Batch):
|
||||
start_seq[start_index:end_index] = batch.start_seq + max_tokens
|
||||
end_seq[start_index:end_index] = batch.end_seq + max_tokens
|
||||
|
||||
past_key_values[
|
||||
:,
|
||||
max_tokens: max_tokens + batch.max_tokens
|
||||
] = batch.past_key_values
|
||||
|
||||
max_seqlen = max(max_seqlen, batch.max_seqlen)
|
||||
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
@ -464,6 +452,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
past_key_values.append(batch.past_key_values)
|
||||
|
||||
# Update
|
||||
cumulative_batch_size += len(batch)
|
||||
@ -480,6 +469,7 @@ class FlashCausalLMBatch(Batch):
|
||||
),
|
||||
)
|
||||
|
||||
past_key_values = torch.cat(past_key_values, dim=0)
|
||||
past_present_indices = end_seq - 1
|
||||
|
||||
all_input_ids_tensor = torch.zeros(
|
||||
@ -726,8 +716,8 @@ class FlashCausalLM(Model):
|
||||
# Set values in batch
|
||||
batch.input_ids = next_input_ids
|
||||
batch.position_ids = next_position_ids + 1
|
||||
batch.past_present_indices = torch.clone(batch.end_seq)
|
||||
batch.end_seq += 1
|
||||
batch.past_present_indices = batch.end_seq
|
||||
batch.end_seq = batch.end_seq + 1
|
||||
|
||||
if prefill and prefill_logprobs:
|
||||
# Get prefill logprobs
|
||||
|
Loading…
Reference in New Issue
Block a user