mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
working rw 7b
This commit is contained in:
parent
5ff2dc9176
commit
c9e7471742
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from loguru import logger
|
||||
from torch import nn
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@ -139,7 +138,6 @@ class FlashRWAttention(torch.nn.Module):
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
prefill,
|
||||
past_stream
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
||||
@ -159,10 +157,8 @@ class FlashRWAttention(torch.nn.Module):
|
||||
|
||||
# Prefill
|
||||
if prefill:
|
||||
past_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(past_stream):
|
||||
# Copy to layer past
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
# Copy to layer past
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
@ -190,7 +186,6 @@ class FlashRWAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
torch.cuda.current_stream().wait_stream(past_stream)
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
@ -437,7 +432,6 @@ class FlashRWLayer(nn.Module):
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
prefill,
|
||||
past_stream,
|
||||
):
|
||||
if self.parallel_attn:
|
||||
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
@ -454,7 +448,6 @@ class FlashRWLayer(nn.Module):
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
prefill,
|
||||
past_stream
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(ln_hidden_states)
|
||||
@ -601,7 +594,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
)
|
||||
|
||||
self.head_size = self.h[0].self_attention.head_size
|
||||
self.past_stream = torch.cuda.Stream()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -612,6 +604,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_present_indices,
|
||||
past_key_values=None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
@ -623,33 +616,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
|
||||
prefill = True
|
||||
|
||||
with torch.cuda.stream(self.past_stream):
|
||||
# Create past tensor
|
||||
past_key_values = hidden_states.new_zeros(
|
||||
(
|
||||
len(self.h),
|
||||
pre_allocate_past_size,
|
||||
*self.cache_size,
|
||||
)
|
||||
# Create past tensor
|
||||
past_key_values = hidden_states.new_zeros(
|
||||
(
|
||||
len(self.h),
|
||||
pre_allocate_past_size,
|
||||
*self.cache_size,
|
||||
)
|
||||
seq_indices = []
|
||||
for s, e in zip(start_seq, end_seq):
|
||||
seq_indices.append(
|
||||
torch.arange(
|
||||
s,
|
||||
e,
|
||||
dtype=torch.int64,
|
||||
device=self.device
|
||||
)
|
||||
)
|
||||
layer_past_present_indices = torch.cat(seq_indices)
|
||||
from loguru import logger
|
||||
logger.error(f"layer past: {layer_past_present_indices}")
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
prefill = False
|
||||
# Create indices from cumulative sequence lengths
|
||||
layer_past_present_indices = end_seq - 1
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
@ -670,9 +647,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_key_values[i],
|
||||
layer_past_present_indices,
|
||||
past_present_indices,
|
||||
prefill,
|
||||
self.past_stream
|
||||
)
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
@ -699,6 +675,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_present_indices,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
@ -711,6 +688,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_present_indices,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
|
@ -34,6 +34,9 @@ class FlashCausalLMBatch(Batch):
|
||||
input_ids: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
|
||||
# Indices to copy present to the correct indices is the pre-allocated past key values
|
||||
past_present_indices: torch.Tensor
|
||||
|
||||
# tensor of length b holding starting offset of each sequence
|
||||
start_seq: torch.Tensor
|
||||
# tensor of length b holding ending offset of each sequence
|
||||
@ -98,6 +101,7 @@ class FlashCausalLMBatch(Batch):
|
||||
)["input_ids"]
|
||||
|
||||
position_ids = []
|
||||
past_present_indices = []
|
||||
start_seq = []
|
||||
end_seq = []
|
||||
start_seq_prefill = []
|
||||
@ -182,6 +186,10 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
|
||||
request_past_present_indices[:input_length] = 1
|
||||
past_present_indices.append(request_past_present_indices)
|
||||
|
||||
# Update
|
||||
# Remove one as the first token des not have a past
|
||||
cumulative_length += input_length
|
||||
@ -214,13 +222,20 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
|
||||
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
|
||||
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
|
||||
|
||||
if len(pb.requests) > 1:
|
||||
past_present_indices = np.concatenate(past_present_indices)
|
||||
|
||||
start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
|
||||
end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
|
||||
else:
|
||||
past_present_indices = past_present_indices[0]
|
||||
|
||||
start_seq_prefill = start_seq
|
||||
end_seq_prefill = end_seq
|
||||
|
||||
past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = end_seq - 1
|
||||
@ -241,6 +256,7 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
past_present_indices=past_present_indices,
|
||||
start_seq=start_seq,
|
||||
end_seq=end_seq,
|
||||
start_seq_prefill=start_seq_prefill,
|
||||
@ -270,7 +286,7 @@ class FlashCausalLMBatch(Batch):
|
||||
if len(request_ids) == len(self):
|
||||
return self
|
||||
|
||||
single_request = len(request_ids) == 1
|
||||
device = self.input_ids.device
|
||||
|
||||
# Cumulative length
|
||||
cumulative_max_length = 0
|
||||
@ -281,13 +297,15 @@ class FlashCausalLMBatch(Batch):
|
||||
# Used to index into tensors
|
||||
indices = []
|
||||
|
||||
# past indices to keep
|
||||
past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
|
||||
|
||||
# Create on CPU to only move to GPU once instead of at every copy
|
||||
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
|
||||
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
|
||||
start_seq_q = self.start_seq_q[: len(request_ids)]
|
||||
end_seq_q = self.end_seq_q[: len(request_ids)]
|
||||
max_seqlen = 0
|
||||
past_key_values = []
|
||||
|
||||
requests = []
|
||||
all_input_ids = []
|
||||
@ -324,11 +342,8 @@ class FlashCausalLMBatch(Batch):
|
||||
start_seq[i] = cumulative_max_length
|
||||
end_seq[i] = cumulative_max_length + request_input_length
|
||||
|
||||
# Slice from past
|
||||
past_key_values.append(
|
||||
self.past_key_values[:,
|
||||
self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1]
|
||||
)
|
||||
# Set slice
|
||||
past_indices[self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1] = True
|
||||
|
||||
cumulative_max_length += request_input_length + remaining_tokens - 1
|
||||
|
||||
@ -337,16 +352,12 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids = self.position_ids[indices]
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
|
||||
if single_request:
|
||||
past_key_values = past_key_values[0]
|
||||
else:
|
||||
# Cat all past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
past_key_values = self.past_key_values[:, past_indices]
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
start_seq = start_seq.to(self.start_seq.device)
|
||||
end_seq = end_seq.to(self.start_seq.device)
|
||||
start_seq = start_seq.to(device)
|
||||
end_seq = end_seq.to(device)
|
||||
past_present_indices = end_seq - 1
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
batch_id=self.batch_id,
|
||||
@ -354,6 +365,7 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
past_present_indices=past_present_indices,
|
||||
start_seq=start_seq,
|
||||
end_seq=end_seq,
|
||||
start_seq_prefill=None,
|
||||
@ -468,6 +480,8 @@ class FlashCausalLMBatch(Batch):
|
||||
),
|
||||
)
|
||||
|
||||
past_present_indices = end_seq - 1
|
||||
|
||||
all_input_ids_tensor = torch.zeros(
|
||||
(total_batch_size, max_length), dtype=torch.int64, device=device
|
||||
)
|
||||
@ -493,6 +507,7 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
past_present_indices=past_present_indices,
|
||||
start_seq=start_seq,
|
||||
end_seq=end_seq,
|
||||
start_seq_prefill=None,
|
||||
@ -574,6 +589,7 @@ class FlashCausalLM(Model):
|
||||
start_seq_q: Optional[torch.Tensor],
|
||||
end_seq_q: Optional[torch.Tensor],
|
||||
max_s: int,
|
||||
past_present_indices: torch.Tensor,
|
||||
past_key_values: Optional = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
@ -587,6 +603,7 @@ class FlashCausalLM(Model):
|
||||
start_seq_q=start_seq_q,
|
||||
end_seq_q=end_seq_q,
|
||||
max_s=max_s,
|
||||
past_present_indices=past_present_indices,
|
||||
past_key_values=past_key_values,
|
||||
pre_allocate_past_size=pre_allocate_past_size,
|
||||
lm_head_indices=lm_head_indices,
|
||||
@ -619,6 +636,7 @@ class FlashCausalLM(Model):
|
||||
batch.start_seq_q,
|
||||
batch.end_seq_q,
|
||||
batch.max_seqlen,
|
||||
batch.past_present_indices,
|
||||
batch.past_key_values,
|
||||
pre_allocate_past_size,
|
||||
batch.prefill_head_indices,
|
||||
@ -708,6 +726,7 @@ class FlashCausalLM(Model):
|
||||
# Set values in batch
|
||||
batch.input_ids = next_input_ids
|
||||
batch.position_ids = next_position_ids + 1
|
||||
batch.past_present_indices = torch.clone(batch.end_seq)
|
||||
batch.end_seq += 1
|
||||
|
||||
if prefill and prefill_logprobs:
|
||||
|
Loading…
Reference in New Issue
Block a user