fix(server): fix past key values logic (#216)

@njhill fyi
OlivierDehaene 2023-04-21 15:59:18 +02:00 committed by GitHub
parent 343437c7b5
commit db4cb5e4ed
7 changed files with 123 additions and 20 deletions

Changed file 1 of 7: FlashLlama modeling (FlashLlamaModel / FlashLlamaForCausalLM)

@@ -25,6 +25,7 @@ from torch.nn import functional as F
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional
 
 # Flash attention imports
 import rotary_emb
@@ -554,7 +555,8 @@ class FlashLlamaModel(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)
@@ -564,7 +566,9 @@ class FlashLlamaModel(torch.nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
@@ -572,6 +576,7 @@ class FlashLlamaModel(torch.nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -579,6 +584,7 @@ class FlashLlamaModel(torch.nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -588,6 +594,13 @@ class FlashLlamaModel(torch.nn.Module):
 
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -595,7 +608,7 @@ class FlashLlamaModel(torch.nn.Module):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -638,10 +651,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.model(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.lm_head(hidden_states)
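All three flash models change in the same way: at prefill time the past tensor is allocated directly at its final decoding size (pre_allocate_past_size) instead of at the prompt length, and each layer is then handed only the non-padded slice. Below is a minimal standalone sketch of that pattern; the sizes and names (n_layers, prompt_len, and so on) are purely illustrative and not taken from the commit:

    import torch

    # Illustrative sizes only
    n_layers, n_heads, head_size = 2, 4, 8
    prompt_len, max_new_tokens = 5, 11

    # Prefill: allocate the cache once, at prompt length + max_new_tokens
    pre_allocate_past_size = prompt_len + max_new_tokens
    past_key_values = torch.empty(
        (n_layers, pre_allocate_past_size, 2, n_heads, head_size)
    )

    # Only the first prompt_len positions are valid during prefill, so each
    # layer sees a view that hides the trailing padding
    slice_past_index = prompt_len
    for i in range(n_layers):
        layer_past = (
            past_key_values[i]
            if slice_past_index is None
            else past_key_values[i, :slice_past_index]
        )
        assert layer_past.shape[0] == prompt_len  # padding is invisible to the layer

    # Decode: slice_past_index is None and the whole pre-allocated tensor is
    # passed through untouched, so no per-step reallocation is needed
    slice_past_index = None

The GPT-NeoX and Santacoder diffs below apply exactly the same prefill/decode split.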

Changed file 2 of 7: FlashGPTNeoX modeling (FlashGPTNeoXModel / FlashGPTNeoXForCausalLM)

@@ -27,6 +27,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
+from typing import Optional
 
 # Flash attention imports
 import rotary_emb
@@ -618,6 +619,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         cu_seqlens,
         max_s,
         past_key_values=None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_in(input_ids)
@@ -627,7 +629,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
@@ -635,6 +639,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -642,6 +647,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -651,6 +657,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -658,7 +671,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -714,10 +727,16 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.gpt_neox(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.embed_out(hidden_states)

Changed file 3 of 7: FlashSantacoder modeling (FlashSantacoderModel / FlashSantacoderForCausalLM)

@@ -5,6 +5,7 @@ import torch.nn.functional as F
 
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
@@ -484,7 +485,8 @@ class FlashSantacoderModel(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
         if self.tp_embeddings:
@@ -496,7 +498,9 @@ class FlashSantacoderModel(nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.h),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     1,
                     self.head_size,
@@ -504,6 +508,7 @@ class FlashSantacoderModel(nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -511,15 +516,23 @@ class FlashSantacoderModel(nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         residual = None
         for i, layer in enumerate(self.h):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -554,10 +567,16 @@ class FlashSantacoderForCausalLM(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.transformer(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.lm_head(hidden_states)

Changed file 4 of 7: FlashCausalLM / FlashCausalLMBatch (shared flash batching and generation logic)

@@ -142,6 +142,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            past_pad=None,
         )
 
     @tracer.start_as_current_span("filter")
@@ -188,8 +189,10 @@ class FlashCausalLMBatch(Batch):
                 cu_seqlens.append(cumulative_length + request_input_length)
                 max_seqlen = max(max_seqlen, request_input_length)
                 if not single_request:
+                    # True index for past
                     past_key_values.append(self.past_key_values[2 * idx])
-                    past_key_values.append(self.past_key_values[1])
+                    # Add one padding
+                    past_key_values.append(self.past_pad)
 
                 all_input_ids.append(self.all_input_ids[idx])
                 all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -207,7 +210,17 @@ class FlashCausalLMBatch(Batch):
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
                 self.past_key_values[0],
-                (0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
+                (
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    stopping_criterias[0].max_new_tokens
+                    - stopping_criterias[0].current_tokens,
+                ),
             )
 
         return FlashCausalLMBatch(
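A note on the reshaped pad call above: torch.nn.functional.pad consumes its pad tuple two values per dimension, starting from the last dimension. For a past tensor laid out as (layers, sequence, 2, heads, head_size), the three leading (0, 0) pairs leave head_size, heads, and the key/value axis untouched, and the final pair appends the remaining-token padding to the sequence axis. A quick sanity check with arbitrary, made-up shapes:

    import torch
    import torch.nn.functional as F

    # Illustrative layout: (layers, sequence, 2, heads, head_size)
    past = torch.zeros(3, 5, 2, 4, 8)
    remaining_tokens = 7  # stands in for max_new_tokens - current_tokens

    # Pairs are read from the last dimension backwards:
    # (head_size: 0, 0), (heads: 0, 0), (kv: 0, 0), (sequence: 0, remaining_tokens)
    padded = F.pad(past, (0, 0, 0, 0, 0, 0, 0, remaining_tokens))
    print(padded.shape)  # torch.Size([3, 12, 2, 4, 8]); only the sequence dim grew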
@@ -270,10 +283,16 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
 
             if len(batch) != 1:
                 past_key_values.extend(batch.past_key_values)
             else:
-                past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
+                # past was pre-allocated for this batch
+                # We need to slice to remove the padding
+                past_key_values.append(
+                    batch.past_key_values[:, : batch.input_lengths[0]]
+                )
+                # Add one padding
                 past_key_values.append(batch.past_pad)
 
             all_input_ids.extend(batch.all_input_ids)
@@ -366,6 +385,7 @@ class FlashCausalLM(Model):
         cu_seqlens: torch.Tensor,
         max_s: int,
         past_key_values: Optional = None,
+        pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -374,6 +394,7 @@ class FlashCausalLM(Model):
             cu_seqlens=cu_seqlens,
             max_s=max_s,
             past_key_values=past_key_values,
+            pre_allocate_past_size=pre_allocate_past_size,
         )
 
     @tracer.start_as_current_span("generate_token")
@@ -382,7 +403,9 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
         if len(batch) == 1:
-            # No need to slice this down
+            input_ids = batch.input_ids[0].view(-1)
+            # Slice to remove extra padding
+            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
             past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
@@ -393,6 +416,16 @@ class FlashCausalLM(Model):
             else None
         )
 
+        # if prefill and bs == 1
+        if past_key_values is None and len(batch) == 1:
+            # Ask to pre-allocate kv to its max size
+            # == number of tokens + max_new_tokens
+            pre_allocate_past_size = (
+                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
+            )
+        else:
+            pre_allocate_past_size = None
+
         # Concatenate when prefill, torch.tensor when decode
         position_ids = (
             torch.tensor(batch.position_ids, device=self.device)
@@ -409,21 +442,28 @@ class FlashCausalLM(Model):
             cu_seqlens,
             batch.max_seqlen,
             past_key_values,
+            pre_allocate_past_size,
         )
 
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
             # Initialize past padding tensor
             if self.past_pad is None:
-                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+                self.past_pad = present.new_zeros(
+                    present.shape[0], 1, *present.shape[2:]
+                )
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
             if len(batch) == 1:
                 # Preallocate tensor for bs = 1 case
                 batch.past_key_values = torch.nn.functional.pad(
-                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
+                    present,
+                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
                 )
             else:
+                # Add padding after each sequence
+                # This will have the correct shape after the final past_key_values concatenation before the model
+                # forward
                 batch.past_key_values = [None, self.past_pad] * len(batch)
 
         # Cumulative length
@@ -555,6 +595,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
             if len(batch) != 1:
+                # Add each sequence before its padding
                 batch.past_key_values[i * 2] = present[:, start_index:end_index]
             # Cumulative sum
             batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
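Taken together, the batched path (len(batch) != 1) keeps past_key_values as a plain Python list that alternates one tensor per sequence with one shared padding slot: [kv_0, pad, kv_1, pad, ...]. generate_token writes each sequence's slice of present into the even indices ("before its padding"), filter() reads them back with past_key_values[2 * idx], and the padding entries leave every sequence one spare position once the list is concatenated before the next forward. A rough sketch of that layout; the tensors, names, and sizes here are toy stand-ins for the real cache:

    import torch

    n_layers, n_heads, head_size = 2, 4, 8
    input_lengths = [3, 5]  # two sequences in the batch
    past_pad = torch.zeros(n_layers, 1, 2, n_heads, head_size)

    # Decode-time layout: one [None, padding] pair per sequence,
    # the None slots are filled in after the model forward
    past_key_values = [None, past_pad] * len(input_lengths)

    # Pretend `present` is the cache the model returned for the whole batch
    present = torch.randn(n_layers, sum(input_lengths), 2, n_heads, head_size)

    start = 0
    for i, length in enumerate(input_lengths):
        end = start + length
        # Each sequence is written right before its padding slot
        past_key_values[i * 2] = present[:, start:end]
        start = end

    # Concatenating along the sequence dim leaves one extra (padded) position
    # per sequence, where the next decoding step can write its key/value
    merged = torch.cat(past_key_values, dim=1)
    assert merged.shape[1] == sum(input_lengths) + len(input_lengths)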

Changed file 5 of 7: FlashLlama / FlashLlamaSharded model classes

@@ -29,6 +29,7 @@ tracer = trace.get_tracer(__name__)
 
 class FlashLlama(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -146,6 +147,7 @@ class FlashLlamaSharded(FlashLlama):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():

Changed file 6 of 7: FlashNeoXSharded model class

@@ -33,6 +33,7 @@ class FlashNeoXSharded(FlashNeoX):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():

Changed file 7 of 7: FlashSantacoder / FlashSantacoderSharded model classes

@@ -28,6 +28,7 @@ tracer = trace.get_tracer(__name__)
 
 class FlashSantacoder(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -172,6 +173,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():