working rw 7b

This commit is contained in:
OlivierDehaene 2023-06-01 13:32:48 +02:00
parent 5ff2dc9176
commit c9e7471742
2 changed files with 47 additions and 50 deletions

View File

@ -1,7 +1,6 @@
import torch
import torch.distributed
from loguru import logger
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
@ -139,7 +138,6 @@ class FlashRWAttention(torch.nn.Module):
layer_past,
layer_past_present_indices,
prefill,
past_stream
):
qkv = self.query_key_value(hidden_states)
@ -159,10 +157,8 @@ class FlashRWAttention(torch.nn.Module):
# Prefill
if prefill:
past_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(past_stream):
# Copy to layer past
layer_past[layer_past_present_indices] = kv
# Copy to layer past
layer_past[layer_past_present_indices] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@ -190,7 +186,6 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
torch.cuda.current_stream().wait_stream(past_stream)
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
# Expand to query shape
@ -437,7 +432,6 @@ class FlashRWLayer(nn.Module):
layer_past,
layer_past_present_indices,
prefill,
past_stream,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -454,7 +448,6 @@ class FlashRWLayer(nn.Module):
layer_past,
layer_past_present_indices,
prefill,
past_stream
)
mlp_output = self.mlp(ln_hidden_states)
@ -601,7 +594,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
)
self.head_size = self.h[0].self_attention.head_size
self.past_stream = torch.cuda.Stream()
def forward(
self,
@ -612,6 +604,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
@ -623,33 +616,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
prefill = True
with torch.cuda.stream(self.past_stream):
# Create past tensor
past_key_values = hidden_states.new_zeros(
(
len(self.h),
pre_allocate_past_size,
*self.cache_size,
)
# Create past tensor
past_key_values = hidden_states.new_zeros(
(
len(self.h),
pre_allocate_past_size,
*self.cache_size,
)
seq_indices = []
for s, e in zip(start_seq, end_seq):
seq_indices.append(
torch.arange(
s,
e,
dtype=torch.int64,
device=self.device
)
)
layer_past_present_indices = torch.cat(seq_indices)
from loguru import logger
logger.error(f"layer past: {layer_past_present_indices}")
)
# Decode
else:
prefill = False
# Create indices from cumulative sequence lengths
layer_past_present_indices = end_seq - 1
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@ -670,9 +647,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
end_seq_q,
max_s,
past_key_values[i],
layer_past_present_indices,
past_present_indices,
prefill,
self.past_stream
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -699,6 +675,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -711,6 +688,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@ -34,6 +34,9 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor
position_ids: torch.Tensor
# Indices to copy present to the correct indices is the pre-allocated past key values
past_present_indices: torch.Tensor
# tensor of length b holding starting offset of each sequence
start_seq: torch.Tensor
# tensor of length b holding ending offset of each sequence
@ -98,6 +101,7 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
past_present_indices = []
start_seq = []
end_seq = []
start_seq_prefill = []
@ -182,6 +186,10 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
request_past_present_indices[:input_length] = 1
past_present_indices.append(request_past_present_indices)
# Update
# Remove one as the first token des not have a past
cumulative_length += input_length
@ -214,13 +222,20 @@ class FlashCausalLMBatch(Batch):
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
if len(pb.requests) > 1:
past_present_indices = np.concatenate(past_present_indices)
start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
else:
past_present_indices = past_present_indices[0]
start_seq_prefill = start_seq
end_seq_prefill = end_seq
past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = end_seq - 1
@ -241,6 +256,7 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=start_seq_prefill,
@ -270,7 +286,7 @@ class FlashCausalLMBatch(Batch):
if len(request_ids) == len(self):
return self
single_request = len(request_ids) == 1
device = self.input_ids.device
# Cumulative length
cumulative_max_length = 0
@ -281,13 +297,15 @@ class FlashCausalLMBatch(Batch):
# Used to index into tensors
indices = []
# past indices to keep
past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
# Create on CPU to only move to GPU once instead of at every copy
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
start_seq_q = self.start_seq_q[: len(request_ids)]
end_seq_q = self.end_seq_q[: len(request_ids)]
max_seqlen = 0
past_key_values = []
requests = []
all_input_ids = []
@ -324,11 +342,8 @@ class FlashCausalLMBatch(Batch):
start_seq[i] = cumulative_max_length
end_seq[i] = cumulative_max_length + request_input_length
# Slice from past
past_key_values.append(
self.past_key_values[:,
self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1]
)
# Set slice
past_indices[self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1] = True
cumulative_max_length += request_input_length + remaining_tokens - 1
@ -337,16 +352,12 @@ class FlashCausalLMBatch(Batch):
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
if single_request:
past_key_values = past_key_values[0]
else:
# Cat all past
past_key_values = torch.cat(past_key_values, dim=1)
past_key_values = self.past_key_values[:, past_indices]
# Move to GPU now that we have the whole tensor
start_seq = start_seq.to(self.start_seq.device)
end_seq = end_seq.to(self.start_seq.device)
start_seq = start_seq.to(device)
end_seq = end_seq.to(device)
past_present_indices = end_seq - 1
return FlashCausalLMBatch(
batch_id=self.batch_id,
@ -354,6 +365,7 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
@ -468,6 +480,8 @@ class FlashCausalLMBatch(Batch):
),
)
past_present_indices = end_seq - 1
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
@ -493,6 +507,7 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
@ -574,6 +589,7 @@ class FlashCausalLM(Model):
start_seq_q: Optional[torch.Tensor],
end_seq_q: Optional[torch.Tensor],
max_s: int,
past_present_indices: torch.Tensor,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -587,6 +603,7 @@ class FlashCausalLM(Model):
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
max_s=max_s,
past_present_indices=past_present_indices,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices,
@ -619,6 +636,7 @@ class FlashCausalLM(Model):
batch.start_seq_q,
batch.end_seq_q,
batch.max_seqlen,
batch.past_present_indices,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices,
@ -708,6 +726,7 @@ class FlashCausalLM(Model):
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.past_present_indices = torch.clone(batch.end_seq)
batch.end_seq += 1
if prefill and prefill_logprobs: