add other models

OlivierDehaene 2023-06-05 16:55:53 +02:00
parent 3fc87f93bd
commit c509e4e79d
6 changed files with 337 additions and 274 deletions

View File

@@ -1,9 +1,9 @@
-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := c5b2a9b7baba2d3059888dbeb03a3cea7aba6e1d
flash-attention:
# Clone flash attention
pip install packaging
-git clone https://github.com/HazyResearch/flash-attention.git
+git clone https://github.com/OlivierDehaene/flash-attention.git
build-flash-attention: flash-attention
cd flash-attention && git fetch && git checkout $(flash_att_commit)

View File

@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
import dropout_layer_norm
from text_generation_server.utils.layers import (
@@ -128,34 +128,42 @@ class FlashLlamaAttention(torch.nn.Module):
hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+past_present_indices,
+prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+query, kv = qkv.split([1, 2], dim=1)
+query = query.view(-1, self.num_heads, self.head_size)
# Inplace rotary
-self.rotary_emb(qkv[:, 0], cos, sin)
-self.rotary_emb(qkv[:, 1], cos, sin)
+self.rotary_emb(query, cos, sin)
+self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
-if layer_past_present_indices is None:
+if prefill:
# Copy to layer past
-layer_past[...] = qkv[:, 1:]
+layer_past[past_present_indices] = kv
# output
-attn_output = torch.empty_like(qkv[:, 0])
+attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
-qkv[:, 0],
-qkv[:, 1],
-qkv[:, 2],
+flash_attn_cuda.fwd(
+query,
+torch.select(kv, dim=1, index=0),
+torch.select(kv, dim=1, index=1),
attn_output,
-cu_seqlens,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq,
+end_seq,
max_s,
max_s,
0.0,
@@ -168,20 +176,21 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
-query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
-layer_past[layer_past_present_indices] = qkv[:, 1:]
+layer_past[past_present_indices] = kv
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
-layer_past[:, 0],
-layer_past[:, 1],
+torch.select(kv, dim=1, index=0),
+torch.select(kv, dim=1, index=1),
attn_output,
-cu_seqlens_q,
-cu_seqlens,
+start_seq_q,
+end_seq_q,
+start_seq,
+end_seq,
1,
max_s,
0.0,
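Note (not part of the diff): throughout this commit the single cumulative-lengths tensor cu_seqlens is replaced by explicit per-sequence boundaries start_seq/end_seq (and start_seq_q/end_seq_q on the query side), presumably so sequences no longer have to be packed back-to-back in the flat buffer. A minimal sketch of how the two representations relate; the sequence lengths below are made up:

import torch

lengths = torch.tensor([3, 5, 2], dtype=torch.int32)  # hypothetical packed sequence lengths

# Old style: one cumulative vector; sequence i spans cu_seqlens[i]..cu_seqlens[i+1]
cu_seqlens = torch.nn.functional.pad(torch.cumsum(lengths, 0), (1, 0)).to(torch.int32)

# New style: an explicit start and end per sequence
start_seq = cu_seqlens[:-1].clone()
end_seq = cu_seqlens[1:].clone()

assert torch.equal(end_seq - start_seq, lengths)
print(cu_seqlens.tolist())                    # [0, 3, 8, 10]
print(start_seq.tolist(), end_seq.tolist())   # [0, 3, 8] [3, 8, 10]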
@@ -258,11 +267,14 @@ class FlashLlamaLayer(nn.Module):
residual,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+past_present_indices,
+prefill,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -271,11 +283,14 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+past_present_indices,
+prefill,
)
# faster post attention rms norm
@@ -322,35 +337,36 @@ class FlashLlamaModel(torch.nn.Module):
self,
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
-past_key_values: Optional[torch.Tensor] = None,
+past_present_indices,
+past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
+assert pre_allocate_past_size is not None
+prefill = True
# Create past tensor
past_key_values = hidden_states.new_empty(
(
+pre_allocate_past_size,
len(self.layers),
-len(hidden_states)
-if pre_allocate_past_size is None
-else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
)
)
-layer_past_present_indices = None
-slice_past_index = len(hidden_states)
# Decode
else:
-# Create indices from cumulative sequence lengths
-layer_past_present_indices = cu_seqlens[1:] - 1
-slice_past_index = None
+prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
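Note (not part of the diff): the model forward now requires a preallocated cache size, keeps one flat past_key_values tensor of pre_allocate_past_size token slots, and distinguishes prefill from decode with an explicit prefill flag plus flat past_present_indices into those slots. A rough sketch of that layout as I read it; shapes and index values are invented:

import torch

num_layers, num_heads, head_size = 2, 4, 8
pre_allocate_past_size = 16   # total token slots reserved for the whole batch
past_key_values = torch.empty(pre_allocate_past_size, num_layers, 2, num_heads, head_size)

# Prefill: each request owns a range of slots; scatter its fresh K/V there.
past_present_indices = torch.tensor([0, 1, 2, 6, 7])        # e.g. request 0 -> slots 0..2, request 1 -> slots 6..7
kv = torch.randn(5, 2, num_heads, head_size)                 # K/V of the 5 prefill tokens, one layer
layer_past = torch.select(past_key_values, dim=1, index=0)   # view of layer 0: (16, 2, num_heads, head_size)
layer_past[past_present_indices] = kv                        # writes through the view into the flat cache

# Decode would reuse the same pattern with one index per request, pointing at its next free slot.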
@@ -360,23 +376,19 @@ class FlashLlamaModel(torch.nn.Module):
residual = None
for i, layer in enumerate(self.layers):
-# We added padding that we now need to slice
-layer_past_key_values = (
-past_key_values[i]
-if slice_past_index is None
-else past_key_values[i, :slice_past_index]
-)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
-layer_past_key_values,
-layer_past_present_indices,
-cu_seqlens_q,
+torch.select(past_key_values, dim=1, index=i),
+past_present_indices,
+prefill,
)
hidden_states, _ = self.norm(hidden_states, residual)
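Note (not part of the diff): the per-layer past is now taken with torch.select(past_key_values, dim=1, index=i) instead of plain indexing. Both forms return a view of the same storage, so no copy is introduced; a quick illustrative check:

import torch

t = torch.arange(24).reshape(2, 3, 4)
a = t[:, 1]
b = torch.select(t, dim=1, index=1)
assert torch.equal(a, b)
assert a.data_ptr() == b.data_ptr()   # same underlying storage, no copy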
@@ -399,9 +411,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self,
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
+past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,9 +424,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
hidden_states, present = self.model(
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
+past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
# Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -113,34 +113,42 @@ class FlashNeoxAttention(torch.nn.Module):
hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+past_present_indices,
+prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+query, kv = qkv.split([1, 2], dim=1)
+query = query.view(-1, self.num_heads, self.head_size)
# Inplace rotary
-self.rotary_emb(qkv[:, 0], cos, sin)
-self.rotary_emb(qkv[:, 1], cos, sin)
+self.rotary_emb(query, cos, sin)
+self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
-if layer_past_present_indices is None:
+if prefill:
# Copy to layer past
-layer_past[...] = qkv[:, 1:]
+layer_past[past_present_indices] = kv
# output
-attn_output = torch.empty_like(qkv[:, 0])
+attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
-qkv[:, 0],
-qkv[:, 1],
-qkv[:, 2],
+flash_attn_cuda.fwd(
+query,
+torch.select(kv, dim=1, index=0),
+torch.select(kv, dim=1, index=1),
attn_output,
-cu_seqlens,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq,
+end_seq,
max_s,
max_s,
0.0,
@@ -153,20 +161,21 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
-query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
-layer_past[layer_past_present_indices] = qkv[:, 1:]
+layer_past[past_present_indices] = kv
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
-layer_past[:, 0],
-layer_past[:, 1],
+torch.select(kv, dim=1, index=0),
+torch.select(kv, dim=1, index=1),
attn_output,
-cu_seqlens_q,
-cu_seqlens,
+start_seq_q,
+end_seq_q,
+start_seq,
+end_seq,
1,
max_s,
0.0,
@@ -240,11 +249,14 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+past_present_indices,
+prefill,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -253,11 +265,14 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+past_present_indices,
+prefill,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -275,11 +290,14 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
-layer_past_present_indices,
-cu_seqlens_q,
+past_present_indices,
+prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@@ -328,9 +346,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self,
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
+past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
@@ -338,25 +359,23 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
# Prefill
if past_key_values is None:
+assert pre_allocate_past_size is not None
+prefill = True
# Create past tensor
past_key_values = hidden_states.new_empty(
(
+pre_allocate_past_size,
len(self.layers),
-len(hidden_states)
-if pre_allocate_past_size is None
-else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
)
)
-layer_past_present_indices = None
-slice_past_index = len(hidden_states)
# Decode
else:
-# Create indices from cumulative sequence lengths
-layer_past_present_indices = cu_seqlens[1:] - 1
-slice_past_index = None
+prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@@ -366,23 +385,19 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual = None
for i, layer in enumerate(self.layers):
-# We added padding that we now need to slice
-layer_past_key_values = (
-past_key_values[i]
-if slice_past_index is None
-else past_key_values[i, :slice_past_index]
-)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
-layer_past_key_values,
-layer_past_present_indices,
-cu_seqlens_q,
+torch.select(past_key_values, dim=1, index=i),
+past_present_indices,
+prefill,
)
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@@ -403,9 +418,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
self,
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
+past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@@ -413,9 +431,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
hidden_states, present = self.gpt_neox(
input_ids,
position_ids,
-cu_seqlens,
-cu_seqlens_q,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
+past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional
# Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -165,7 +165,7 @@ class FlashRWAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@@ -194,7 +194,7 @@ class FlashRWAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@@ -268,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module):
hidden_states,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
past_present_indices,
-cu_seqlens_q,
+prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -287,12 +290,12 @@
# Inplace rotary
self.rotary_emb(query, cos, sin)
-self.rotary_emb(kv[:, :, 0], cos, sin)
+self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
# Prefill
-if past_present_indices is None:
+if prefill:
# Copy to layer past
-layer_past[...] = kv
+layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
kv.unsqueeze(2)
@@ -303,13 +306,15 @@
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
-kv[:, :, 0],
-kv[:, :, 1],
+torch.select(kv, dim=2, index=0),
+torch.select(kv, dim=2, index=1),
attn_output,
-cu_seqlens,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq,
+end_seq,
max_s,
max_s,
0.0,
@@ -334,13 +339,15 @@
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
-kv[:, :, 0],
-kv[:, :, 1],
+torch.select(kv, dim=2, index=0),
+torch.select(kv, dim=2, index=1),
attn_output,
-cu_seqlens_q,
-cu_seqlens,
+start_seq_q,
+end_seq_q,
+start_seq,
+end_seq,
1,
max_s,
0.0,
@@ -514,11 +521,14 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
past_present_indices,
-cu_seqlens_q,
+prefill,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
@@ -528,11 +538,14 @@
ln_attn,
cos,
sin,
-cu_seqlens,
+start_seq,
+end_seq,
+start_seq_q,
+end_seq_q,
max_s,
layer_past,
past_present_indices,
-cu_seqlens_q,
+prefill,
)
# MLP.
@@ -570,7 +583,6 @@
self.h[0].self_attention.head_size,
)
elif config.model_type == "RefinedWeb":
-raise NotImplementedError
self.h = nn.ModuleList(
[
FlashRWLargeLayer(layer_id, config, weights)
@@ -617,7 +629,7 @@
prefill = True
# Create past tensor
-past_key_values = hidden_states.new_zeros(
+past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
@@ -646,7 +658,7 @@
start_seq_q,
end_seq_q,
max_s,
-past_key_values[:, i],
+torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)

View File

@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -179,7 +179,7 @@ class FlashMQAttention(torch.nn.Module):
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
@@ -208,7 +208,7 @@
# output
attn_output = torch.empty_like(query)
# flash attention
-flash_attn_cuda_modif.fwd(
+flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
@@ -373,13 +373,7 @@ class FlashSantacoderModel(nn.Module):
# Create past tensor
past_key_values = hidden_states.new_zeros(
-(
-pre_allocate_past_size,
-len(self.h),
-2,
-1,
-self.head_size
-)
+(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# Decode
else:

View File

@@ -184,7 +184,11 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
-request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
+request_past_present_indices = torch.arange(
+cumulative_max_length,
+cumulative_max_length + input_length,
+dtype=torch.int64,
+)
past_present_indices.append(request_past_present_indices)
# Update
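Note (not part of the diff): as I read the batch-construction change, each request reserves a contiguous run of slots in the flat cache sized for its prompt plus every token it may still generate, which matches the later "cumulative_max_length += request_input_length + remaining_tokens - 1" update. A minimal sketch under that assumption; the request sizes are made up:

import torch

cumulative_max_length = 0
past_present_indices = []
for input_length, max_new_tokens in [(3, 4), (2, 2)]:   # hypothetical requests
    request_past_present_indices = torch.arange(
        cumulative_max_length,
        cumulative_max_length + input_length,
        dtype=torch.int64,
    )
    past_present_indices.append(request_past_present_indices)
    # Reserve room for the prompt plus all future decode tokens so the cache
    # never has to be reallocated while this request is generating.
    cumulative_max_length += input_length + max_new_tokens - 1

print([t.tolist() for t in past_present_indices])   # [[0, 1, 2], [6, 7]]
print(cumulative_max_length)                         # 9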
@@ -217,8 +221,12 @@
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
-start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
-end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
+start_seq_prefill = torch.tensor(
+start_seq_prefill, device=device, dtype=torch.int32
+)
+end_seq_prefill = torch.tensor(
+end_seq_prefill, device=device, dtype=torch.int32
+)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
@@ -230,7 +238,9 @@
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
+past_present_indices = torch.tensor(
+past_present_indices, device=device, dtype=torch.int64
+)
if all_prefill_logprobs:
prefill_head_indices = None
@@ -294,7 +304,9 @@
indices = []
# past indices to keep
-past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)
+past_indices = torch.zeros(
+self.past_key_values.shape[0], dtype=torch.bool, device=device
+)
# Create on CPU to only move to GPU once instead of at every copy
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
@@ -332,14 +344,18 @@
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
-remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+remaining_tokens = (
+stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+)
# Copy to tensor (CPU)
start_seq[i] = cumulative_max_length
end_seq[i] = cumulative_max_length + request_input_length
# Set slice
-past_indices[self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1] = True
+past_indices[
+self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
+] = True
cumulative_max_length += request_input_length + remaining_tokens - 1
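Note (not part of the diff): when requests are filtered out of a batch, the boolean past_indices mask appears to keep only the flat cache slots that still belong to live requests. A rough sketch of that pattern; the shapes are invented:

import torch

past_key_values = torch.randn(9, 2, 2, 4, 8)      # (total_slots, layers, 2, heads, head_size)
past_indices = torch.zeros(9, dtype=torch.bool)

# Suppose only the request occupying slots 6..8 survives the filter.
past_indices[6:9] = True
past_key_values = past_key_values[past_indices]    # compacted cache with 3 slots

print(past_key_values.shape)                        # torch.Size([3, 2, 2, 4, 8])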
@@ -647,7 +663,9 @@ class FlashCausalLM(Model):
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.start_seq_q and batch.end_seq_q for decode
-batch.start_seq_q = torch.arange(0, len(batch), device=self.device, dtype=torch.int32)
+batch.start_seq_q = torch.arange(
+0, len(batch), device=self.device, dtype=torch.int32
+)
batch.end_seq_q = batch.start_seq_q + 1
next_position_ids = batch.position_ids.new_empty(len(batch))
# We do not need start_seq_prefill and end_seq_prefill anymore
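Note (not part of the diff): during decode every sequence contributes exactly one query token, so the query boundaries built above are just trivial per-batch ranges. A small sketch; the batch size is invented:

import torch

batch_size = 4
start_seq_q = torch.arange(0, batch_size, dtype=torch.int32)
end_seq_q = start_seq_q + 1
print(start_seq_q.tolist(), end_seq_q.tolist())   # [0, 1, 2, 3] [1, 2, 3, 4]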