mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
wip
This commit is contained in:
parent
abd58ff82c
commit
5ff2dc9176
@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
import flash_attn_cuda_modif
|
||||
import dropout_layer_norm
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
@ -149,7 +149,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
@ -175,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
layer_past[:, 0],
|
||||
layer_past[:, 1],
|
||||
|
@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
import flash_attn_cuda_modif
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -134,7 +134,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
@ -160,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
layer_past[:, 0],
|
||||
layer_past[:, 1],
|
||||
|
@ -1,13 +1,14 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from loguru import logger
|
||||
from torch import nn
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
import flash_attn_cuda_modif
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -130,11 +131,15 @@ class FlashRWAttention(torch.nn.Module):
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
prefill,
|
||||
past_stream
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
||||
@ -153,22 +158,26 @@ class FlashRWAttention(torch.nn.Module):
|
||||
self.rotary_emb(kv[:, 0], cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
if prefill:
|
||||
past_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(past_stream):
|
||||
# Copy to layer past
|
||||
layer_past[...] = kv
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq,
|
||||
end_seq,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
@ -181,6 +190,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
torch.cuda.current_stream().wait_stream(past_stream)
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
@ -189,13 +199,15 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
1,
|
||||
max_s,
|
||||
0.0,
|
||||
@ -296,7 +308,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
@ -327,7 +339,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
@ -417,11 +429,15 @@ class FlashRWLayer(nn.Module):
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
prefill,
|
||||
past_stream,
|
||||
):
|
||||
if self.parallel_attn:
|
||||
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
@ -430,11 +446,15 @@ class FlashRWLayer(nn.Module):
|
||||
ln_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
prefill,
|
||||
past_stream
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(ln_hidden_states)
|
||||
@ -450,11 +470,14 @@ class FlashRWLayer(nn.Module):
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
prefill,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
@ -554,6 +577,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
self.h[0].self_attention.head_size,
|
||||
)
|
||||
elif config.model_type == "RefinedWeb":
|
||||
raise NotImplementedError
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLargeLayer(layer_id, config, weights)
|
||||
@ -577,13 +601,16 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
)
|
||||
|
||||
self.head_size = self.h[0].self_attention.head_size
|
||||
self.past_stream = torch.cuda.Stream()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
@ -592,23 +619,37 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
|
||||
# Prefill
|
||||
if past_key_values is None:
|
||||
assert pre_allocate_past_size is not None
|
||||
|
||||
prefill = True
|
||||
|
||||
with torch.cuda.stream(self.past_stream):
|
||||
# Create past tensor
|
||||
past_key_values = hidden_states.new_empty(
|
||||
past_key_values = hidden_states.new_zeros(
|
||||
(
|
||||
len(self.h),
|
||||
len(hidden_states)
|
||||
if pre_allocate_past_size is None
|
||||
else pre_allocate_past_size,
|
||||
pre_allocate_past_size,
|
||||
*self.cache_size,
|
||||
)
|
||||
)
|
||||
layer_past_present_indices = None
|
||||
slice_past_index = len(hidden_states)
|
||||
seq_indices = []
|
||||
for s, e in zip(start_seq, end_seq):
|
||||
seq_indices.append(
|
||||
torch.arange(
|
||||
s,
|
||||
e,
|
||||
dtype=torch.int64,
|
||||
device=self.device
|
||||
)
|
||||
)
|
||||
layer_past_present_indices = torch.cat(seq_indices)
|
||||
from loguru import logger
|
||||
logger.error(f"layer past: {layer_past_present_indices}")
|
||||
# Decode
|
||||
else:
|
||||
prefill = False
|
||||
# Create indices from cumulative sequence lengths
|
||||
layer_past_present_indices = cu_seqlens[1:] - 1
|
||||
slice_past_index = None
|
||||
layer_past_present_indices = end_seq - 1
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
@ -618,23 +659,20 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.h):
|
||||
# We added padding that we now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
else past_key_values[i, :slice_past_index]
|
||||
)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
layer_past_key_values,
|
||||
past_key_values[i],
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
prefill,
|
||||
self.past_stream
|
||||
)
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
@ -656,8 +694,10 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
@ -666,8 +706,10 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
hidden_states, present = self.transformer(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
start_seq_q,
|
||||
end_seq_q,
|
||||
max_s,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
|
@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
import flash_attn_cuda_modif
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -175,7 +175,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
key_value[:, 0],
|
||||
key_value[:, 1],
|
||||
@ -202,7 +202,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
flash_attn_cuda_modif.fwd(
|
||||
query,
|
||||
key_value[:, 0],
|
||||
key_value[:, 1],
|
||||
|
@ -34,10 +34,18 @@ class FlashCausalLMBatch(Batch):
|
||||
input_ids: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
|
||||
# cumulative sequence lengths
|
||||
cu_seqlens: torch.Tensor
|
||||
# cumulative query sequence lengths, only used in decode
|
||||
cu_seqlens_q: Optional[torch.Tensor]
|
||||
# tensor of length b holding starting offset of each sequence
|
||||
start_seq: torch.Tensor
|
||||
# tensor of length b holding ending offset of each sequence
|
||||
end_seq: torch.Tensor
|
||||
# tensor of length b holding starting offset of each sequence, only used in prefill
|
||||
start_seq_prefill: Optional[torch.Tensor]
|
||||
# tensor of length b holding ending offset of each sequence, only used in prefill
|
||||
end_seq_prefill: Optional[torch.Tensor]
|
||||
# tensor of length b holding starting offset of each query sequence, only used in decode
|
||||
start_seq_q: Optional[torch.Tensor]
|
||||
# tensor of length b holding ending offset of each query sequence, only used in decode
|
||||
end_seq_q: Optional[torch.Tensor]
|
||||
# past key values, only used in decode
|
||||
past_key_values: Optional[torch.Tensor]
|
||||
max_seqlen: int
|
||||
@ -90,7 +98,10 @@ class FlashCausalLMBatch(Batch):
|
||||
)["input_ids"]
|
||||
|
||||
position_ids = []
|
||||
cu_seqlens = [0]
|
||||
start_seq = []
|
||||
end_seq = []
|
||||
start_seq_prefill = []
|
||||
end_seq_prefill = []
|
||||
max_seqlen = 0
|
||||
|
||||
input_lengths = []
|
||||
@ -110,9 +121,9 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
max_tokens = 0
|
||||
max_length = 0
|
||||
|
||||
# Parse batch
|
||||
@ -138,7 +149,10 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlens.append(cumulative_length + input_length)
|
||||
start_seq_prefill.append(cumulative_length)
|
||||
end_seq_prefill.append(cumulative_length + input_length)
|
||||
start_seq.append(cumulative_max_length)
|
||||
end_seq.append(cumulative_max_length + input_length)
|
||||
|
||||
next_token_chooser_parameters.append(r.parameters)
|
||||
|
||||
@ -169,8 +183,9 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
# Update
|
||||
# Remove one as the first token des not have a past
|
||||
cumulative_length += input_length
|
||||
max_tokens += input_length + max_new_tokens
|
||||
cumulative_max_length += input_length + max_new_tokens - 1
|
||||
max_length = max(max_length, input_length + max_new_tokens)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
@ -197,13 +212,20 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
|
||||
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
|
||||
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
|
||||
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
|
||||
if len(pb.requests) > 1:
|
||||
start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
|
||||
end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
|
||||
else:
|
||||
start_seq_prefill = start_seq
|
||||
end_seq_prefill = end_seq
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlens[1:] - 1
|
||||
prefill_next_token_indices = end_seq - 1
|
||||
elif no_prefill_logprobs:
|
||||
prefill_head_indices = cu_seqlens[1:] - 1
|
||||
prefill_head_indices = end_seq - 1
|
||||
prefill_next_token_indices = None
|
||||
else:
|
||||
prefill_head_indices = torch.tensor(
|
||||
@ -219,8 +241,12 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_seqlens_q=None,
|
||||
start_seq=start_seq,
|
||||
end_seq=end_seq,
|
||||
start_seq_prefill=start_seq_prefill,
|
||||
end_seq_prefill=end_seq_prefill,
|
||||
start_seq_q=None,
|
||||
end_seq_q=None,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=prefill_head_indices,
|
||||
prefill_next_token_indices=prefill_next_token_indices,
|
||||
@ -233,7 +259,7 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
max_tokens=cumulative_max_length,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
@ -247,7 +273,7 @@ class FlashCausalLMBatch(Batch):
|
||||
single_request = len(request_ids) == 1
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
|
||||
# New values after filtering
|
||||
requests_idx_mapping = {}
|
||||
@ -256,8 +282,10 @@ class FlashCausalLMBatch(Batch):
|
||||
indices = []
|
||||
|
||||
# Create on CPU to only move to GPU once instead of at every copy
|
||||
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
|
||||
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
|
||||
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
|
||||
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
|
||||
start_seq_q = self.start_seq_q[: len(request_ids)]
|
||||
end_seq_q = self.end_seq_q[: len(request_ids)]
|
||||
max_seqlen = 0
|
||||
past_key_values = []
|
||||
|
||||
@ -270,8 +298,6 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
stopping_criterias = []
|
||||
|
||||
max_tokens = 0
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
indices.append(idx)
|
||||
@ -281,16 +307,8 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
# Get length
|
||||
request_input_length = self.input_lengths[idx]
|
||||
|
||||
# Copy to tensor (CPU)
|
||||
cu_seqlens[i + 1] = cumulative_length + request_input_length
|
||||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
|
||||
# Slice from past
|
||||
past_key_values.append(
|
||||
self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
|
||||
)
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
|
||||
input_lengths.append(request_input_length)
|
||||
@ -300,30 +318,19 @@ class FlashCausalLMBatch(Batch):
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
cumulative_length += request_input_length
|
||||
max_tokens += request_input_length + (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
|
||||
# Copy to tensor (CPU)
|
||||
start_seq[i] = cumulative_max_length
|
||||
end_seq[i] = cumulative_max_length + request_input_length
|
||||
|
||||
# Slice from past
|
||||
past_key_values.append(
|
||||
self.past_key_values[:,
|
||||
self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1]
|
||||
)
|
||||
|
||||
if single_request:
|
||||
# Preallocate tensor for bs = 1 case
|
||||
past_key_values = F.pad(
|
||||
past_key_values[0],
|
||||
(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
stopping_criterias[0].max_new_tokens
|
||||
- stopping_criterias[0].current_tokens,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# Cat all past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
cumulative_max_length += request_input_length + remaining_tokens - 1
|
||||
|
||||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
@ -331,8 +338,15 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
|
||||
if single_request:
|
||||
past_key_values = past_key_values[0]
|
||||
else:
|
||||
# Cat all past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
|
||||
start_seq = start_seq.to(self.start_seq.device)
|
||||
end_seq = end_seq.to(self.start_seq.device)
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
batch_id=self.batch_id,
|
||||
@ -340,8 +354,12 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
start_seq=start_seq,
|
||||
end_seq=end_seq,
|
||||
start_seq_prefill=None,
|
||||
end_seq_prefill=None,
|
||||
start_seq_q=start_seq_q,
|
||||
end_seq_q=end_seq_q,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=None,
|
||||
prefill_next_token_indices=None,
|
||||
@ -354,7 +372,7 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
max_tokens=cumulative_max_length,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -365,18 +383,25 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping = {}
|
||||
|
||||
total_batch_size = sum([len(b) for b in batches])
|
||||
total_tokens = sum(b.max_tokens for b in batches)
|
||||
|
||||
dtype = batches[0].past_key_values.dtype
|
||||
device = batches[0].input_ids.device
|
||||
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||
cu_seqlens = [0]
|
||||
cu_seqlens_q = torch.arange(
|
||||
0, total_batch_size + 1, device=device, dtype=torch.int32
|
||||
start_seq = batches[0].start_seq.new_empty(total_batch_size)
|
||||
end_seq = batches[0].end_seq.new_empty(total_batch_size)
|
||||
start_seq_q = torch.arange(
|
||||
0, total_batch_size, device=device, dtype=torch.int32
|
||||
)
|
||||
end_seq_q = start_seq_q + 1
|
||||
max_seqlen = 0
|
||||
past_key_values = []
|
||||
past_key_values = batches[0].past_key_values.new_empty((
|
||||
batches[0].past_key_values.shape[0],
|
||||
total_tokens,
|
||||
*batches[0].past_key_values.shape[2:]
|
||||
))
|
||||
|
||||
all_input_ids = []
|
||||
|
||||
@ -389,7 +414,6 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
# Cumulative length
|
||||
cumulative_batch_size = 0
|
||||
cumulative_length = 0
|
||||
max_tokens = 0
|
||||
max_length = 0
|
||||
|
||||
@ -410,18 +434,15 @@ class FlashCausalLMBatch(Batch):
|
||||
input_ids[start_index:end_index] = batch.input_ids
|
||||
position_ids[start_index:end_index] = batch.position_ids
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
|
||||
max_seqlen = max(max_seqlen, batch.max_seqlen)
|
||||
start_seq[start_index:end_index] = batch.start_seq + max_tokens
|
||||
end_seq[start_index:end_index] = batch.end_seq + max_tokens
|
||||
|
||||
if len(batch) != 1:
|
||||
past_key_values.append(batch.past_key_values)
|
||||
else:
|
||||
# past was pre-allocated for this batch
|
||||
# We need to slice to remove the padding
|
||||
past_key_values.append(
|
||||
batch.past_key_values[:, : batch.input_lengths[0]]
|
||||
)
|
||||
past_key_values[
|
||||
:,
|
||||
max_tokens: max_tokens + batch.max_tokens
|
||||
] = batch.past_key_values
|
||||
|
||||
max_seqlen = max(max_seqlen, batch.max_seqlen)
|
||||
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
|
||||
@ -433,7 +454,6 @@ class FlashCausalLMBatch(Batch):
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
# Update
|
||||
cumulative_length += batch.cu_seqlens[-1]
|
||||
cumulative_batch_size += len(batch)
|
||||
max_tokens += batch.max_tokens
|
||||
max_length = max(
|
||||
@ -463,11 +483,6 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
cumulative_batch_size += len(batch)
|
||||
|
||||
# Cat past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
# Create final tensor on GPU
|
||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype=dtype, device=device
|
||||
)
|
||||
@ -478,8 +493,12 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
start_seq=start_seq,
|
||||
end_seq=end_seq,
|
||||
start_seq_prefill=None,
|
||||
end_seq_prefill=None,
|
||||
start_seq_q=start_seq_q,
|
||||
end_seq_q=end_seq_q,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=None,
|
||||
prefill_next_token_indices=None,
|
||||
@ -550,8 +569,10 @@ class FlashCausalLM(Model):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cu_seqlens_q: Optional[torch.Tensor],
|
||||
start_seq: torch.Tensor,
|
||||
end_seq: torch.Tensor,
|
||||
start_seq_q: Optional[torch.Tensor],
|
||||
end_seq_q: Optional[torch.Tensor],
|
||||
max_s: int,
|
||||
past_key_values: Optional = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
@ -561,8 +582,10 @@ class FlashCausalLM(Model):
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
start_seq=start_seq,
|
||||
end_seq=end_seq,
|
||||
start_seq_q=start_seq_q,
|
||||
end_seq_q=end_seq_q,
|
||||
max_s=max_s,
|
||||
past_key_values=past_key_values,
|
||||
pre_allocate_past_size=pre_allocate_past_size,
|
||||
@ -579,18 +602,22 @@ class FlashCausalLM(Model):
|
||||
|
||||
if prefill and single_request:
|
||||
# Ask to pre-allocate kv to its max size
|
||||
# == number of tokens + max_new_tokens
|
||||
pre_allocate_past_size = (
|
||||
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
# == Sum over batch size (number of tokens + max_new_tokens) - batch size
|
||||
pre_allocate_past_size = batch.max_tokens
|
||||
start_seq = batch.start_seq_prefill
|
||||
end_seq = batch.end_seq_prefill
|
||||
else:
|
||||
pre_allocate_past_size = None
|
||||
start_seq = batch.start_seq
|
||||
end_seq = batch.end_seq
|
||||
|
||||
out, present = self.forward(
|
||||
batch.input_ids,
|
||||
batch.position_ids,
|
||||
batch.cu_seqlens,
|
||||
batch.cu_seqlens_q,
|
||||
start_seq,
|
||||
end_seq,
|
||||
batch.start_seq_q,
|
||||
batch.end_seq_q,
|
||||
batch.max_seqlen,
|
||||
batch.past_key_values,
|
||||
pre_allocate_past_size,
|
||||
@ -614,55 +641,17 @@ class FlashCausalLM(Model):
|
||||
# When batch == 1, we will just use the batch.input_ids values directly
|
||||
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
|
||||
|
||||
# Create batch.cu_seqlens_q for decode
|
||||
batch.cu_seqlens_q = torch.arange(
|
||||
0, len(batch) + 1, device=self.device, dtype=torch.int32
|
||||
)
|
||||
# Create batch.start_seq_q and batch.end_seq_q for decode
|
||||
batch.start_seq_q = torch.arange(0, len(batch), device=self.device, dtype=torch.int32)
|
||||
batch.end_seq_q = batch.start_seq_q + 1
|
||||
next_position_ids = batch.position_ids.new_empty(len(batch))
|
||||
# We do not need start_seq_prefill and end_seq_prefill anymore
|
||||
batch.start_seq_prefill = None
|
||||
batch.end_seq_prefill = None
|
||||
else:
|
||||
prefill_logprobs = None
|
||||
next_position_ids = batch.position_ids
|
||||
|
||||
# Prepare past for next decode
|
||||
if len(batch) > 1:
|
||||
# Used to slice next batch past
|
||||
past_indices = torch.empty(
|
||||
present.shape[1], dtype=torch.int64, device=self.device
|
||||
)
|
||||
batch.past_key_values = present.new_empty(
|
||||
(
|
||||
present.shape[0],
|
||||
present.shape[1] + len(batch.requests),
|
||||
*present.shape[2:],
|
||||
)
|
||||
)
|
||||
|
||||
# It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
|
||||
# and will run asynchronously while we do the next for loop
|
||||
cumulative_length = 0
|
||||
for i, input_length in enumerate(batch.input_lengths):
|
||||
# Indexing metadata
|
||||
start_index = cumulative_length
|
||||
end_index = cumulative_length + input_length
|
||||
|
||||
# Indices to copy present at the correct place in past_key_values
|
||||
torch.arange(
|
||||
start_index + i,
|
||||
end_index + i,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
out=past_indices[start_index:end_index],
|
||||
)
|
||||
cumulative_length += input_length
|
||||
|
||||
# Copy from present to past_key_values
|
||||
batch.past_key_values[:, past_indices] = present
|
||||
|
||||
# Initialize past_key_values in prefill for len(batch) == 1
|
||||
elif prefill:
|
||||
# present is already pre-padded
|
||||
batch.past_key_values = present
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
|
||||
@ -685,6 +674,7 @@ class FlashCausalLM(Model):
|
||||
input_length,
|
||||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Indexing metadata
|
||||
start_index = cumulative_length
|
||||
end_index = cumulative_length + input_length
|
||||
|
||||
@ -718,7 +708,7 @@ class FlashCausalLM(Model):
|
||||
# Set values in batch
|
||||
batch.input_ids = next_input_ids
|
||||
batch.position_ids = next_position_ids + 1
|
||||
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
|
||||
batch.end_seq += 1
|
||||
|
||||
if prefill and prefill_logprobs:
|
||||
# Get prefill logprobs
|
||||
@ -843,6 +833,7 @@ class FlashCausalLM(Model):
|
||||
batch.prefill_head_indices = None
|
||||
batch.prefill_next_token_indices = None
|
||||
batch.max_seqlen = batch.max_seqlen + 1
|
||||
batch.past_key_values = present
|
||||
|
||||
# No need to return a batch if we know that all requests stopped
|
||||
return generations, batch if not stopped else None
|
||||
|
Loading…
Reference in New Issue
Block a user