OlivierDehaene 2023-06-01 10:05:24 +02:00
parent abd58ff82c
commit 5ff2dc9176
5 changed files with 310 additions and 277 deletions

View File

@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
-import flash_attn_cuda
+import flash_attn_cuda_modif
import dropout_layer_norm
from text_generation_server.utils.layers import (
@@ -149,7 +149,7 @@ class FlashLlamaAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(qkv[:, 0])
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@@ -175,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],

View File

@@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
# Flash attention imports
-import flash_attn_cuda
+import flash_attn_cuda_modif
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -134,7 +134,7 @@ class FlashNeoxAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(qkv[:, 0])
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@@ -160,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],

View File

@@ -1,13 +1,14 @@
import torch
import torch.distributed
+from loguru import logger
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional
# Flash attention imports
-import flash_attn_cuda
+import flash_attn_cuda_modif
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -42,25 +43,25 @@ class RWConfig(PretrainedConfig):
}
def __init__(
self,
model_type="RefinedWeb",
vocab_size=250880,
hidden_size=64,
n_layer=2,
n_head=8,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
hidden_dropout=0.0,
attention_dropout=0.0,
n_head_kv=None,
multi_query=False,
alibi=False,
bias=False,
parallel_attn=False,
**kwargs,
):
if alibi:
raise NotImplementedError(
@@ -126,15 +127,19 @@ class FlashRWAttention(torch.nn.Module):
)
def forward(
self,
hidden_states,
cos,
sin,
-cu_seqlens,
-max_s,
-layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
+max_s,
+layer_past,
+layer_past_present_indices,
+prefill,
+past_stream
):
qkv = self.query_key_value(hidden_states)
@@ -153,22 +158,26 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(kv[:, 0], cos, sin)
# Prefill
-if layer_past_present_indices is None:
-# Copy to layer past
-layer_past[...] = kv
+if prefill:
+past_stream.wait_stream(torch.cuda.current_stream())
+with torch.cuda.stream(past_stream):
+# Copy to layer past
+layer_past[layer_past_present_indices] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
kv[:, 0],
kv[:, 1],
attn_output,
-cu_seqlens,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq,
+end_seq,
max_s,
max_s,
0.0,
@@ -181,6 +190,7 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
+torch.cuda.current_stream().wait_stream(past_stream)
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
# Expand to query shape
@@ -189,13 +199,15 @@ class FlashRWAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
kv[:, 0],
kv[:, 1],
attn_output,
-cu_seqlens_q,
-cu_seqlens,
+start_seq_q,
+end_seq_q,
+start_seq,
+end_seq,
1,
max_s,
0.0,
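
In the prefill branch above, the copy of the freshly computed key/value pairs into the pre-allocated `layer_past` cache is launched on a dedicated CUDA stream so that it can overlap with the flash-attention kernel running on the default stream; the decode branch then waits on that stream before touching the cache again. A minimal standalone sketch of this pattern (hypothetical tensor shapes and slot indices, requires a CUDA device):

```python
import torch

device = torch.device("cuda")
past_stream = torch.cuda.Stream()

# Hypothetical pre-allocated cache: 11 slots, 2 (K/V), 8 heads, head_size 64.
layer_past = torch.zeros(11, 2, 8, 64, device=device)

# Prefill: 2 sequences of 3 and 4 tokens packed together (7 K/V pairs),
# destined for slots 0-2 and 5-8 of the cache.
kv = torch.randn(7, 2, 8, 64, device=device)
prefill_indices = torch.tensor([0, 1, 2, 5, 6, 7, 8], device=device)

past_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(past_stream):
    # the cache write runs on the side stream ...
    layer_past[prefill_indices] = kv
# ... while flash attention over the packed tokens runs here on the default stream.

# Decode: make the default stream wait for the copy before reading/writing the cache.
torch.cuda.current_stream().wait_stream(past_stream)
decode_kv = torch.randn(2, 2, 8, 64, device=device)
layer_past[torch.tensor([3, 9], device=device)] = decode_kv  # one new row per sequence
```

The two `wait_stream` calls are what keep the copy ordered with respect to both the prefill kernel that produced `kv` and the first decode step that reads the cache.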
@@ -257,15 +269,15 @@ class FlashRWLargeAttention(torch.nn.Module):
)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -296,7 +308,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
@@ -327,7 +339,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
@@ -412,16 +424,20 @@ class FlashRWLayer(nn.Module):
self.process_group = weights.process_group
def forward(
self,
hidden_states,
residual,
cos,
sin,
-cu_seqlens,
-max_s,
-layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
+max_s,
+layer_past,
+layer_past_present_indices,
+prefill,
+past_stream,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -430,11 +446,15 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
-cu_seqlens_q,
+prefill,
+past_stream
)
mlp_output = self.mlp(ln_hidden_states)
@@ -450,11 +470,14 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
-cu_seqlens_q,
+prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@@ -493,16 +516,16 @@ class FlashRWLargeLayer(nn.Module):
self.process_group = weights.process_group
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
@@ -554,6 +577,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
self.h[0].self_attention.head_size,
)
elif config.model_type == "RefinedWeb":
+raise NotImplementedError
self.h = nn.ModuleList(
[
FlashRWLargeLayer(layer_id, config, weights)
@@ -577,38 +601,55 @@ class FlashRWModel(FlashRWPreTrainedModel):
)
self.head_size = self.h[0].self_attention.head_size
+self.past_stream = torch.cuda.Stream()
def forward(
self,
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
-max_s,
-past_key_values=None,
-pre_allocate_past_size: Optional[int] = None,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
+max_s,
+past_key_values=None,
+pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
-# Create past tensor
-past_key_values = hidden_states.new_empty(
-(
-len(self.h),
-len(hidden_states)
-if pre_allocate_past_size is None
-else pre_allocate_past_size,
-*self.cache_size,
-)
-)
-layer_past_present_indices = None
-slice_past_index = len(hidden_states)
+assert pre_allocate_past_size is not None
+prefill = True
+with torch.cuda.stream(self.past_stream):
+# Create past tensor
+past_key_values = hidden_states.new_zeros(
+(
+len(self.h),
+pre_allocate_past_size,
+*self.cache_size,
+)
+)
+seq_indices = []
+for s, e in zip(start_seq, end_seq):
+seq_indices.append(
+torch.arange(
+s,
+e,
+dtype=torch.int64,
+device=self.device
+)
+)
+layer_past_present_indices = torch.cat(seq_indices)
+from loguru import logger
+logger.error(f"layer past: {layer_past_present_indices}")
# Decode
else:
+prefill = False
# Create indices from cumulative sequence lengths
-layer_past_present_indices = cu_seqlens[1:] - 1
-slice_past_index = None
+layer_past_present_indices = end_seq - 1
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@@ -618,23 +659,20 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual = None
for i, layer in enumerate(self.h):
-# We added padding that we now need to slice
-layer_past_key_values = (
-past_key_values[i]
-if slice_past_index is None
-else past_key_values[i, :slice_past_index]
-)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
-layer_past_key_values,
+past_key_values[i],
layer_past_present_indices,
-cu_seqlens_q,
+prefill,
+self.past_stream
)
hidden_states, _ = self.ln_f(hidden_states, residual)
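
During prefill the model no longer writes the whole packed `kv` into a same-sized past; every token gets an explicit destination row in the pre-allocated cache, built above with one `torch.arange` per sequence. A small CPU-only sketch of that index construction, with hypothetical offsets:

```python
import torch

# Hypothetical cache offsets for two requests whose prompt tokens occupy
# rows 0-2 and 5-8 of the pre-allocated past (the gap is room for new tokens).
start_seq = torch.tensor([0, 5], dtype=torch.int32)
end_seq = torch.tensor([3, 9], dtype=torch.int32)

# Destination row of every prefill token, mirroring the arange/cat loop above.
seq_indices = [
    torch.arange(int(s), int(e), dtype=torch.int64)
    for s, e in zip(start_seq, end_seq)
]
layer_past_present_indices = torch.cat(seq_indices)
print(layer_past_present_indices)  # tensor([0, 1, 2, 5, 6, 7, 8])

# During decode a single row per request is written, so the indices collapse to
# end_seq - 1 once end_seq has been advanced past the last stored token.
```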
@@ -653,21 +691,25 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
)
def forward(
self,
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
-max_s,
-past_key_values: Optional[torch.Tensor] = None,
-pre_allocate_past_size: Optional[int] = None,
-lm_head_indices: Optional[torch.Tensor] = None,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
+max_s,
+past_key_values: Optional[torch.Tensor] = None,
+pre_allocate_past_size: Optional[int] = None,
+lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
past_key_values,
pre_allocate_past_size,

View File

@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
-import flash_attn_cuda
+import flash_attn_cuda_modif
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -175,7 +175,7 @@ class FlashMQAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
key_value[:, 0],
key_value[:, 1],
@@ -202,7 +202,7 @@ class FlashMQAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda.fwd(
+flash_attn_cuda_modif.fwd(
query,
key_value[:, 0],
key_value[:, 1],

View File

@@ -34,10 +34,18 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor
position_ids: torch.Tensor
-# cumulative sequence lengths
-cu_seqlens: torch.Tensor
-# cumulative query sequence lengths, only used in decode
-cu_seqlens_q: Optional[torch.Tensor]
+# tensor of length b holding starting offset of each sequence
+start_seq: torch.Tensor
+# tensor of length b holding ending offset of each sequence
+end_seq: torch.Tensor
+# tensor of length b holding starting offset of each sequence, only used in prefill
+start_seq_prefill: Optional[torch.Tensor]
+# tensor of length b holding ending offset of each sequence, only used in prefill
+end_seq_prefill: Optional[torch.Tensor]
+# tensor of length b holding starting offset of each query sequence, only used in decode
+start_seq_q: Optional[torch.Tensor]
+# tensor of length b holding ending offset of each query sequence, only used in decode
+end_seq_q: Optional[torch.Tensor]
# past key values, only used in decode
past_key_values: Optional[torch.Tensor]
max_seqlen: int
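
The batch now tracks explicit per-sequence start/end pairs instead of one cumulative-lengths vector. A rough illustration of how the new fields relate to the old `cu_seqlens` for a hypothetical batch with prompt lengths 3 and 4 and a budget of 2 new tokens each (values only, not the real `from_pb` logic):

```python
import torch

# Old representation: sequence i spans cu_seqlens[i]..cu_seqlens[i + 1]
# in the packed token dimension.
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)

# New packed offsets, used only for the prefill attention call
# (the same boundaries as cu_seqlens, split into per-sequence pairs).
start_seq_prefill = torch.tensor([0, 3], dtype=torch.int32)
end_seq_prefill = torch.tensor([3, 7], dtype=torch.int32)

# New cache offsets: each sequence owns input_length + max_new_tokens - 1
# rows of the pre-allocated past, so the pairs are no longer back to back.
start_seq = torch.tensor([0, 4], dtype=torch.int32)  # 4 = 0 + (3 + 2 - 1)
end_seq = torch.tensor([3, 8], dtype=torch.int32)    # start + current length

# Decode runs exactly one query token per sequence, hence the trivial offsets.
start_seq_q = torch.arange(0, 2, dtype=torch.int32)
end_seq_q = start_seq_q + 1
```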
@@ -73,11 +81,11 @@ class FlashCausalLMBatch(Batch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_inputs = []
max_truncation = 0
@@ -90,7 +98,10 @@
)["input_ids"]
position_ids = []
-cu_seqlens = [0]
+start_seq = []
+end_seq = []
+start_seq_prefill = []
+end_seq_prefill = []
max_seqlen = 0
input_lengths = []
@@ -110,9 +121,9 @@
# Cumulative length
cumulative_length = 0
+cumulative_max_length = 0
prefill_out_cumulative_length = 0
-max_tokens = 0
max_length = 0
# Parse batch
@@ -138,7 +149,10 @@
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
-cu_seqlens.append(cumulative_length + input_length)
+start_seq_prefill.append(cumulative_length)
+end_seq_prefill.append(cumulative_length + input_length)
+start_seq.append(cumulative_max_length)
+end_seq.append(cumulative_max_length + input_length)
next_token_chooser_parameters.append(r.parameters)
@@ -169,8 +183,9 @@
prefill_out_cumulative_length += 1
# Update
+# Remove one as the first token does not have a past
cumulative_length += input_length
-max_tokens += input_length + max_new_tokens
+cumulative_max_length += input_length + max_new_tokens - 1
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -197,13 +212,20 @@
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
+start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
+end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
+if len(pb.requests) > 1:
+start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
+end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
+else:
+start_seq_prefill = start_seq
+end_seq_prefill = end_seq
if all_prefill_logprobs:
prefill_head_indices = None
-prefill_next_token_indices = cu_seqlens[1:] - 1
+prefill_next_token_indices = end_seq - 1
elif no_prefill_logprobs:
-prefill_head_indices = cu_seqlens[1:] - 1
+prefill_head_indices = end_seq - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
@@ -219,8 +241,12 @@
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
-cu_seqlens=cu_seqlens,
-cu_seqlens_q=None,
+start_seq=start_seq,
+end_seq=end_seq,
+start_seq_prefill=start_seq_prefill,
+end_seq_prefill=end_seq_prefill,
+start_seq_q=None,
+end_seq_q=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
@@ -233,7 +259,7 @@
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
-max_tokens=max_tokens,
+max_tokens=cumulative_max_length,
)
@tracer.start_as_current_span("filter")
@@ -247,7 +273,7 @@
single_request = len(request_ids) == 1
# Cumulative length
-cumulative_length = 0
+cumulative_max_length = 0
# New values after filtering
requests_idx_mapping = {}
@@ -256,8 +282,10 @@
indices = []
# Create on CPU to only move to GPU once instead of at every copy
-cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
-cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
+start_seq = torch.empty(len(request_ids), dtype=torch.int32)
+end_seq = torch.empty(len(request_ids), dtype=torch.int32)
+start_seq_q = self.start_seq_q[: len(request_ids)]
+end_seq_q = self.end_seq_q[: len(request_ids)]
max_seqlen = 0
past_key_values = []
@@ -270,8 +298,6 @@
stopping_criterias = []
-max_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
@@ -281,16 +307,8 @@
# Get length
request_input_length = self.input_lengths[idx]
-# Copy to tensor (CPU)
-cu_seqlens[i + 1] = cumulative_length + request_input_length
max_seqlen = max(max_seqlen, request_input_length)
-# Slice from past
-past_key_values.append(
-self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
-)
all_input_ids.append(self.all_input_ids[idx])
input_lengths.append(request_input_length)
@@ -300,30 +318,19 @@
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
-cumulative_length += request_input_length
-max_tokens += request_input_length + (
-stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-)
-if single_request:
-# Preallocate tensor for bs = 1 case
-past_key_values = F.pad(
-past_key_values[0],
-(
-0,
-0,
-0,
-0,
-0,
-0,
-0,
-stopping_criterias[0].max_new_tokens
-- stopping_criterias[0].current_tokens,
-),
-)
-else:
-# Cat all past
-past_key_values = torch.cat(past_key_values, dim=1)
+remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+# Copy to tensor (CPU)
+start_seq[i] = cumulative_max_length
+end_seq[i] = cumulative_max_length + request_input_length
+# Slice from past
+past_key_values.append(
+self.past_key_values[:,
+self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1]
+)
+cumulative_max_length += request_input_length + remaining_tokens - 1
# Index into tensors
input_ids = self.input_ids[indices]
@@ -331,8 +338,15 @@
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
+if single_request:
+past_key_values = past_key_values[0]
+else:
+# Cat all past
+past_key_values = torch.cat(past_key_values, dim=1)
# Move to GPU now that we have the whole tensor
-cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
+start_seq = start_seq.to(self.start_seq.device)
+end_seq = end_seq.to(self.start_seq.device)
return FlashCausalLMBatch(
batch_id=self.batch_id,
@@ -340,8 +354,12 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
-cu_seqlens=cu_seqlens,
-cu_seqlens_q=cu_seqlens_q,
+start_seq=start_seq,
+end_seq=end_seq,
+start_seq_prefill=None,
+end_seq_prefill=None,
+start_seq_q=start_seq_q,
+end_seq_q=end_seq_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
@@ -354,7 +372,7 @@
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
-max_tokens=max_tokens,
+max_tokens=cumulative_max_length,
)
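
`filter` now has to carve each surviving request's block out of the padded past together with the slots it still needs for future tokens, and re-pack the offsets from zero. A standalone sketch of that bookkeeping (hypothetical shapes and lengths, not the actual method):

```python
import torch

# Hypothetical padded past: [num_layers * 2, total_slots, num_heads, head_size]
# holding 3 requests whose slot blocks are rows 0-5, 6-11 and 12-19.
past_key_values = torch.randn(4, 20, 8, 64)
start_seq = torch.tensor([0, 6, 12])
end_seq = torch.tensor([4, 9, 15])   # start + current length of each sequence
remaining = [3, 4, 6]                # max_new_tokens - current_tokens per request

keep = [0, 2]                        # surviving request indices

new_start, new_end, blocks = [], [], []
cumulative = 0
for i in keep:
    length = int(end_seq[i] - start_seq[i])
    new_start.append(cumulative)
    new_end.append(cumulative + length)
    # take the request's stored tokens plus the room reserved for its remaining tokens
    blocks.append(past_key_values[:, start_seq[i]: end_seq[i] + remaining[i] - 1])
    cumulative += length + remaining[i] - 1

filtered_past = torch.cat(blocks, dim=1)
print(filtered_past.shape)             # torch.Size([4, 14, 8, 64])
print(new_start, new_end, cumulative)  # [0, 6] [4, 9] 14
```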
@classmethod
@@ -365,18 +383,25 @@
requests_idx_mapping = {}
total_batch_size = sum([len(b) for b in batches])
+total_tokens = sum(b.max_tokens for b in batches)
dtype = batches[0].past_key_values.dtype
device = batches[0].input_ids.device
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
-cu_seqlens = [0]
-cu_seqlens_q = torch.arange(
-0, total_batch_size + 1, device=device, dtype=torch.int32
-)
+start_seq = batches[0].start_seq.new_empty(total_batch_size)
+end_seq = batches[0].end_seq.new_empty(total_batch_size)
+start_seq_q = torch.arange(
+0, total_batch_size, device=device, dtype=torch.int32
+)
+end_seq_q = start_seq_q + 1
max_seqlen = 0
-past_key_values = []
+past_key_values = batches[0].past_key_values.new_empty((
+batches[0].past_key_values.shape[0],
+total_tokens,
+*batches[0].past_key_values.shape[2:]
+))
all_input_ids = []
@@ -389,7 +414,6 @@
# Cumulative length
cumulative_batch_size = 0
-cumulative_length = 0
max_tokens = 0
max_length = 0
@@ -410,18 +434,15 @@
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
-# Add cumulative lengths of all previous inputs
-cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
-max_seqlen = max(max_seqlen, batch.max_seqlen)
-if len(batch) != 1:
-past_key_values.append(batch.past_key_values)
-else:
-# past was pre-allocated for this batch
-# We need to slice to remove the padding
-past_key_values.append(
-batch.past_key_values[:, : batch.input_lengths[0]]
-)
+start_seq[start_index:end_index] = batch.start_seq + max_tokens
+end_seq[start_index:end_index] = batch.end_seq + max_tokens
+past_key_values[
+:,
+max_tokens: max_tokens + batch.max_tokens
+] = batch.past_key_values
+max_seqlen = max(max_seqlen, batch.max_seqlen)
all_input_ids.extend(batch.all_input_ids)
@@ -433,7 +454,6 @@
stopping_criterias.extend(batch.stopping_criterias)
# Update
-cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_length = max(
@@ -458,16 +478,11 @@
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
cumulative_batch_size += len(batch)
-# Cat past
-past_key_values = torch.cat(past_key_values, dim=1)
-# Create final tensor on GPU
-cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device
)
@@ -478,8 +493,12 @@
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
-cu_seqlens=cu_seqlens,
-cu_seqlens_q=cu_seqlens_q,
+start_seq=start_seq,
+end_seq=end_seq,
+start_seq_prefill=None,
+end_seq_prefill=None,
+start_seq_q=start_seq_q,
+end_seq_q=end_seq_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
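
`concatenate` no longer does a plain `torch.cat` of the per-batch pasts: it allocates one tensor sized for the sum of every batch's `max_tokens` (cache slots, not just stored tokens) and copies each batch's past into its region, shifting that batch's offsets by the same amount. A reduced sketch with hypothetical shapes:

```python
import torch

# Two hypothetical batches with pre-allocated pasts of 10 and 14 slots each.
past_a = torch.randn(4, 10, 8, 64)
past_b = torch.randn(4, 14, 8, 64)
start_a, end_a = torch.tensor([0, 6]), torch.tensor([3, 8])
start_b, end_b = torch.tensor([0]), torch.tensor([5])
batches = [(past_a, start_a, end_a, 10), (past_b, start_b, end_b, 14)]  # (..., max_tokens)

total_tokens = sum(batch_max_tokens for *_, batch_max_tokens in batches)
past_key_values = past_a.new_empty((past_a.shape[0], total_tokens, *past_a.shape[2:]))

start_seq, end_seq = [], []
max_tokens = 0
for past, start, end, batch_max_tokens in batches:
    # copy this batch's past into its region of the merged tensor
    past_key_values[:, max_tokens: max_tokens + batch_max_tokens] = past
    # shift this batch's offsets into the merged coordinate system
    start_seq.append(start + max_tokens)
    end_seq.append(end + max_tokens)
    max_tokens += batch_max_tokens

start_seq = torch.cat(start_seq)
end_seq = torch.cat(end_seq)
print(past_key_values.shape, start_seq.tolist(), end_seq.tolist())
# torch.Size([4, 24, 8, 64]) [0, 6, 10] [3, 8, 15]
```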
@@ -501,12 +520,12 @@
class FlashCausalLM(Model):
def __init__(
self,
model_cls: Type[PreTrainedModel],
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -547,22 +566,26 @@ class FlashCausalLM(Model):
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
-cu_seqlens: torch.Tensor,
-cu_seqlens_q: Optional[torch.Tensor],
-max_s: int,
-past_key_values: Optional = None,
-pre_allocate_past_size: Optional[int] = None,
-lm_head_indices: Optional[torch.Tensor] = None,
+start_seq: torch.Tensor,
+end_seq: torch.Tensor,
+start_seq_q: Optional[torch.Tensor],
+end_seq_q: Optional[torch.Tensor],
+max_s: int,
+past_key_values: Optional = None,
+pre_allocate_past_size: Optional[int] = None,
+lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
-cu_seqlens=cu_seqlens,
-cu_seqlens_q=cu_seqlens_q,
+start_seq=start_seq,
+end_seq=end_seq,
+start_seq_q=start_seq_q,
+end_seq_q=end_seq_q,
max_s=max_s,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
@@ -571,7 +594,7 @@ class FlashCausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -579,18 +602,22 @@
if prefill and single_request:
# Ask to pre-allocate kv to its max size
-# == number of tokens + max_new_tokens
-pre_allocate_past_size = (
-batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
-)
+# == Sum over batch size (number of tokens + max_new_tokens) - batch size
+pre_allocate_past_size = batch.max_tokens
+start_seq = batch.start_seq_prefill
+end_seq = batch.end_seq_prefill
else:
pre_allocate_past_size = None
+start_seq = batch.start_seq
+end_seq = batch.end_seq
out, present = self.forward(
batch.input_ids,
batch.position_ids,
-batch.cu_seqlens,
-batch.cu_seqlens_q,
+start_seq,
+end_seq,
+batch.start_seq_q,
+batch.end_seq_q,
batch.max_seqlen,
batch.past_key_values,
pre_allocate_past_size,
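
`generate_token` now picks which offsets it hands to the model: on the first (prefill) pass for a single request it uses the packed prefill offsets and asks the model to pre-allocate `batch.max_tokens` cache slots, otherwise it uses the per-slot offsets, and after every pass the write position simply advances with `end_seq += 1`. A condensed, hypothetical sketch of that dispatch (a stand-in `model_forward` and a duck-typed `batch`, not the real class):

```python
import torch

def generate_step(batch, model_forward):
    prefill = batch.past_key_values is None
    single_request = len(batch.requests) == 1

    if prefill and single_request:
        # pre-allocate sum(input_length + max_new_tokens - 1) cache slots
        pre_allocate_past_size = batch.max_tokens
        start_seq, end_seq = batch.start_seq_prefill, batch.end_seq_prefill
    else:
        pre_allocate_past_size = None
        start_seq, end_seq = batch.start_seq, batch.end_seq

    out, present = model_forward(
        batch.input_ids,
        batch.position_ids,
        start_seq,
        end_seq,
        batch.start_seq_q,
        batch.end_seq_q,
        batch.max_seqlen,
        batch.past_key_values,
        pre_allocate_past_size,
    )

    if prefill:
        # decode will run a single query token per sequence
        batch.start_seq_q = torch.arange(0, len(batch.requests), dtype=torch.int32)
        batch.end_seq_q = batch.start_seq_q + 1
        batch.start_seq_prefill = None
        batch.end_seq_prefill = None

    # each generated token occupies the next cache row of its sequence
    batch.end_seq += 1
    batch.max_seqlen += 1
    batch.past_key_values = present
    return out, batch
```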
@@ -614,55 +641,17 @@
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
-# Create batch.cu_seqlens_q for decode
-batch.cu_seqlens_q = torch.arange(
-0, len(batch) + 1, device=self.device, dtype=torch.int32
-)
+# Create batch.start_seq_q and batch.end_seq_q for decode
+batch.start_seq_q = torch.arange(0, len(batch), device=self.device, dtype=torch.int32)
+batch.end_seq_q = batch.start_seq_q + 1
next_position_ids = batch.position_ids.new_empty(len(batch))
+# We do not need start_seq_prefill and end_seq_prefill anymore
+batch.start_seq_prefill = None
+batch.end_seq_prefill = None
else:
prefill_logprobs = None
next_position_ids = batch.position_ids
-# Prepare past for next decode
-if len(batch) > 1:
-# Used to slice next batch past
-past_indices = torch.empty(
-present.shape[1], dtype=torch.int64, device=self.device
-)
-batch.past_key_values = present.new_empty(
-(
-present.shape[0],
-present.shape[1] + len(batch.requests),
-*present.shape[2:],
-)
-)
-# It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
-# and will run asynchronously while we do the next for loop
-cumulative_length = 0
-for i, input_length in enumerate(batch.input_lengths):
-# Indexing metadata
-start_index = cumulative_length
-end_index = cumulative_length + input_length
-# Indices to copy present at the correct place in past_key_values
-torch.arange(
-start_index + i,
-end_index + i,
-dtype=torch.int64,
-device=self.device,
-out=past_indices[start_index:end_index],
-)
-cumulative_length += input_length
-# Copy from present to past_key_values
-batch.past_key_values[:, past_indices] = present
-# Initialize past_key_values in prefill for len(batch) == 1
-elif prefill:
-# present is already pre-padded
-batch.past_key_values = present
# Cumulative length
cumulative_length = 0
@@ -685,6 +674,7 @@
input_length,
all_input_ids,
) in enumerate(iterator):
+# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
@@ -718,7 +708,7 @@
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
-batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
+batch.end_seq += 1
if prefill and prefill_logprobs:
# Get prefill logprobs
@@ -787,7 +777,7 @@
if stop:
# Decode generated tokens
output_text = self.decode(
-all_input_ids[-stopping_criteria.current_tokens :]
+all_input_ids[-stopping_criteria.current_tokens:]
)
generated_text = GeneratedText(
output_text,
@@ -843,6 +833,7 @@
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
+batch.past_key_values = present
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None