mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

working rw 7b

parent 5ff2dc9176
commit c9e7471742
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -1,7 +1,6 @@
 import torch
 import torch.distributed
 
-from loguru import logger
 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
@@ -139,7 +138,6 @@ class FlashRWAttention(torch.nn.Module):
         layer_past,
         layer_past_present_indices,
         prefill,
-        past_stream
     ):
         qkv = self.query_key_value(hidden_states)
 
@@ -159,10 +157,8 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if prefill:
-            past_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(past_stream):
-                # Copy to layer past
-                layer_past[layer_past_present_indices] = kv
+            # Copy to layer past
+            layer_past[layer_past_present_indices] = kv
 
             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
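
Note: the hunk above drops the CUDA side-stream copy (past_stream.wait_stream / torch.cuda.stream) in favor of a plain scatter-assignment on the current stream. A minimal sketch of the resulting cache update, assuming a [slots, 2, kv_heads, head_size] cache layout and multi-query attention with a single kv head (the shapes are illustrative, not taken from the model config):

import torch

num_heads, head_size = 4, 8
# Pre-allocated cache: one slot per token position, key and value stacked on dim 1
layer_past = torch.zeros(16, 2, 1, head_size)
kv = torch.randn(3, 2, 1, head_size)                   # kv for 3 prefill tokens
layer_past_present_indices = torch.tensor([0, 1, 2])   # their cache slots

# Copy to layer past on the current stream; no cross-stream synchronization needed
layer_past[layer_past_present_indices] = kv

# Expand to query shape (multi-query: broadcast the single kv head)
kv = kv.expand(-1, 2, num_heads, head_size)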
@@ -190,7 +186,6 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            torch.cuda.current_stream().wait_stream(past_stream)
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
             # Expand to query shape
@@ -437,7 +432,6 @@ class FlashRWLayer(nn.Module):
         layer_past,
         layer_past_present_indices,
         prefill,
-        past_stream,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -454,7 +448,6 @@ class FlashRWLayer(nn.Module):
                 layer_past,
                 layer_past_present_indices,
                 prefill,
-                past_stream
             )
 
             mlp_output = self.mlp(ln_hidden_states)
@@ -601,7 +594,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
         )
 
         self.head_size = self.h[0].self_attention.head_size
-        self.past_stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -612,6 +604,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         start_seq_q,
         end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -623,33 +616,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
 
             prefill = True
 
-            with torch.cuda.stream(self.past_stream):
-                # Create past tensor
-                past_key_values = hidden_states.new_zeros(
-                    (
-                        len(self.h),
-                        pre_allocate_past_size,
-                        *self.cache_size,
-                    )
-                )
-            seq_indices = []
-            for s, e in zip(start_seq, end_seq):
-                seq_indices.append(
-                    torch.arange(
-                        s,
-                        e,
-                        dtype=torch.int64,
-                        device=self.device
-                    )
-                )
-            layer_past_present_indices = torch.cat(seq_indices)
-            from loguru import logger
-            logger.error(f"layer past: {layer_past_present_indices}")
+            # Create past tensor
+            past_key_values = hidden_states.new_zeros(
+                (
+                    len(self.h),
+                    pre_allocate_past_size,
+                    *self.cache_size,
+                )
+            )
         # Decode
         else:
             prefill = False
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = end_seq - 1
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
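
Note: the deleted prefill path above rebuilt the cache indices on every forward by concatenating one torch.arange per sequence (and logged them for debugging); the rewrite instead receives a past_present_indices tensor computed once when the batch is built. A sketch of what the old path produced, with made-up offsets:

import torch

start_seq = torch.tensor([0, 5])
end_seq = torch.tensor([3, 9])

# Old: one arange per sequence, concatenated on every forward pass
indices = torch.cat(
    [torch.arange(s, e, dtype=torch.int64) for s, e in zip(start_seq, end_seq)]
)
# tensor([0, 1, 2, 5, 6, 7, 8])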
@@ -670,9 +647,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 end_seq_q,
                 max_s,
                 past_key_values[i],
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
-                self.past_stream
             )
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -699,6 +675,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         start_seq_q,
         end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -711,6 +688,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
             start_seq_q,
             end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -34,6 +34,9 @@ class FlashCausalLMBatch(Batch):
     input_ids: torch.Tensor
     position_ids: torch.Tensor
 
+    # Indices to copy present to the correct indices in the pre-allocated past key values
+    past_present_indices: torch.Tensor
+
     # tensor of length b holding starting offset of each sequence
     start_seq: torch.Tensor
     # tensor of length b holding ending offset of each sequence
@@ -98,6 +101,7 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]
 
         position_ids = []
+        past_present_indices = []
         start_seq = []
         end_seq = []
         start_seq_prefill = []
@@ -182,6 +186,10 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1
 
+            request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
+            request_past_present_indices[:input_length] = 1
+            past_present_indices.append(request_past_present_indices)
+
             # Update
             # Remove one as the first token does not have a past
             cumulative_length += input_length
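
Note: each request reserves input_length + max_new_tokens - 1 slots in the shared cache and marks the slots its prompt tokens occupy during prefill; the concatenated masks later become a torch.bool index. A worked example with small made-up lengths:

import numpy as np

input_length, max_new_tokens = 3, 4
request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
request_past_present_indices[:input_length] = 1
# array([1., 1., 1., 0., 0., 0.]) -> prefill writes kv into the first 3 slots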
@@ -214,13 +222,20 @@ class FlashCausalLMBatch(Batch):
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
         end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
 
         if len(pb.requests) > 1:
+            past_present_indices = np.concatenate(past_present_indices)
+
             start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
             end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
         else:
+            past_present_indices = past_present_indices[0]
+
             start_seq_prefill = start_seq
             end_seq_prefill = end_seq
 
+        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
+
         if all_prefill_logprobs:
             prefill_head_indices = None
             prefill_next_token_indices = end_seq - 1
@@ -241,6 +256,7 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
+            past_present_indices=past_present_indices,
             start_seq=start_seq,
             end_seq=end_seq,
             start_seq_prefill=start_seq_prefill,
@@ -270,7 +286,7 @@ class FlashCausalLMBatch(Batch):
         if len(request_ids) == len(self):
             return self
 
-        single_request = len(request_ids) == 1
+        device = self.input_ids.device
 
         # Cumulative length
         cumulative_max_length = 0
@@ -281,13 +297,15 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []
 
+        # past indices to keep
+        past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
+
         # Create on CPU to only move to GPU once instead of at every copy
         start_seq = torch.empty(len(request_ids), dtype=torch.int32)
         end_seq = torch.empty(len(request_ids), dtype=torch.int32)
         start_seq_q = self.start_seq_q[: len(request_ids)]
         end_seq_q = self.end_seq_q[: len(request_ids)]
         max_seqlen = 0
-        past_key_values = []
 
         requests = []
         all_input_ids = []
@@ -324,11 +342,8 @@ class FlashCausalLMBatch(Batch):
             start_seq[i] = cumulative_max_length
             end_seq[i] = cumulative_max_length + request_input_length
 
-            # Slice from past
-            past_key_values.append(
-                self.past_key_values[:,
-                self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1]
-            )
+            # Set slice
+            past_indices[self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1] = True
 
             cumulative_max_length += request_input_length + remaining_tokens - 1
 
@@ -337,16 +352,12 @@ class FlashCausalLMBatch(Batch):
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
-
-        if single_request:
-            past_key_values = past_key_values[0]
-        else:
-            # Cat all past
-            past_key_values = torch.cat(past_key_values, dim=1)
+        past_key_values = self.past_key_values[:, past_indices]
 
         # Move to GPU now that we have the whole tensor
-        start_seq = start_seq.to(self.start_seq.device)
-        end_seq = end_seq.to(self.start_seq.device)
+        start_seq = start_seq.to(device)
+        end_seq = end_seq.to(device)
+        past_present_indices = end_seq - 1
 
         return FlashCausalLMBatch(
             batch_id=self.batch_id,
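
Note: filter() now keeps the surviving requests' cache slots with a single boolean advanced-index instead of slicing self.past_key_values once per request and concatenating the pieces. A minimal sketch, assuming a [layers, slots, 2, heads, head_size] cache layout:

import torch

past_key_values = torch.randn(2, 10, 2, 4, 8)    # [layers, slots, 2, heads, head_size]
past_indices = torch.zeros(10, dtype=torch.bool)
past_indices[0:3] = True                          # slots of the one request we keep

# One copy for all layers, replacing the per-request slice + torch.cat
past_key_values = past_key_values[:, past_indices]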
@@ -354,6 +365,7 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
+            past_present_indices=past_present_indices,
             start_seq=start_seq,
             end_seq=end_seq,
             start_seq_prefill=None,
@@ -468,6 +480,8 @@ class FlashCausalLMBatch(Batch):
             ),
         )
 
+        past_present_indices = end_seq - 1
+
         all_input_ids_tensor = torch.zeros(
             (total_batch_size, max_length), dtype=torch.int64, device=device
         )
@@ -493,6 +507,7 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
+            past_present_indices=past_present_indices,
             start_seq=start_seq,
             end_seq=end_seq,
             start_seq_prefill=None,
@@ -574,6 +589,7 @@ class FlashCausalLM(Model):
         start_seq_q: Optional[torch.Tensor],
         end_seq_q: Optional[torch.Tensor],
         max_s: int,
+        past_present_indices: torch.Tensor,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -587,6 +603,7 @@ class FlashCausalLM(Model):
             start_seq_q=start_seq_q,
             end_seq_q=end_seq_q,
             max_s=max_s,
+            past_present_indices=past_present_indices,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
             lm_head_indices=lm_head_indices,
@@ -619,6 +636,7 @@ class FlashCausalLM(Model):
             batch.start_seq_q,
             batch.end_seq_q,
             batch.max_seqlen,
+            batch.past_present_indices,
             batch.past_key_values,
             pre_allocate_past_size,
             batch.prefill_head_indices,
@@ -708,6 +726,7 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
+        batch.past_present_indices = torch.clone(batch.end_seq)
         batch.end_seq += 1
 
         if prefill and prefill_logprobs:
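
Note: at each decode step every sequence writes its new kv at its current end offset, so the indices are cloned before end_seq is advanced in place. A sketch of the bookkeeping above, with two sequences:

import torch

end_seq = torch.tensor([3, 9], dtype=torch.int32)

# Clone matters: end_seq is mutated in place on the next line
past_present_indices = torch.clone(end_seq)   # next token's cache slot per sequence
end_seq += 1                                  # each sequence grew by one token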