diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att
index ad894bfa..8c565a43 100644
--- a/server/Makefile-flash-att
+++ b/server/Makefile-flash-att
@@ -1,9 +1,9 @@
-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := c5b2a9b7baba2d3059888dbeb03a3cea7aba6e1d
 
 flash-attention:
 	# Clone flash attention
 	pip install packaging
-	git clone https://github.com/HazyResearch/flash-attention.git
+	git clone https://github.com/OlivierDehaene/flash-attention.git
 
 build-flash-attention: flash-attention
 	cd flash-attention && git fetch && git checkout $(flash_att_commit)
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index de1627b1..0f1a1a54 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 import dropout_layer_norm
 
 from text_generation_server.utils.layers import (
@@ -128,34 +128,42 @@ class FlashLlamaAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
+        query, kv = qkv.split([1, 2], dim=1)
+        query = query.view(-1, self.num_heads, self.head_size)
+
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+            flash_attn_cuda.fwd(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -168,20 +176,21 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -258,11 +267,14 @@ class FlashLlamaLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
 
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -271,11 +283,14 @@ class FlashLlamaLayer(nn.Module):
             normed_hidden_states,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )
 
         # faster post attention rms norm
@@ -322,35 +337,36 @@ class FlashLlamaModel(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)
 
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
             past_key_values = hidden_states.new_empty(
                 (
+                    pre_allocate_past_size,
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -360,23 +376,19 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
 
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -399,9 +411,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,9 +424,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         hidden_states, present = self.model(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index e0aa2cb8..55541e45 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -113,34 +113,42 @@ class FlashNeoxAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
+        query, kv = qkv.split([1, 2], dim=1)
+        query = query.view(-1, self.num_heads, self.head_size)
+
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+            flash_attn_cuda.fwd(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -153,20 +161,21 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = kv
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -240,11 +249,14 @@ class FlashNeoXLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -253,11 +265,14 @@ class FlashNeoXLayer(nn.Module):
                 ln1_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
 
             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -275,11 +290,14 @@ class FlashNeoXLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
 
             hidden_states, residual = self.post_attention_layernorm(
@@ -328,9 +346,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -338,25 +359,23 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
             past_key_values = hidden_states.new_empty(
                 (
+                    pre_allocate_past_size,
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -366,23 +385,19 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
 
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )
 
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@@ -403,9 +418,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -413,9 +431,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         hidden_states, present = self.gpt_neox(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index c65fd160..e7665c8d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -42,25 +42,25 @@ class RWConfig(PretrainedConfig):
     }
 
     def __init__(
-            self,
-            model_type="RefinedWeb",
-            vocab_size=250880,
-            hidden_size=64,
-            n_layer=2,
-            n_head=8,
-            layer_norm_epsilon=1e-5,
-            initializer_range=0.02,
-            use_cache=True,
-            bos_token_id=1,
-            eos_token_id=2,
-            hidden_dropout=0.0,
-            attention_dropout=0.0,
-            n_head_kv=None,
-            multi_query=False,
-            alibi=False,
-            bias=False,
-            parallel_attn=False,
-            **kwargs,
+        self,
+        model_type="RefinedWeb",
+        vocab_size=250880,
+        hidden_size=64,
+        n_layer=2,
+        n_head=8,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        use_cache=True,
+        bos_token_id=1,
+        eos_token_id=2,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        n_head_kv=None,
+        multi_query=False,
+        alibi=False,
+        bias=False,
+        parallel_attn=False,
+        **kwargs,
     ):
         if alibi:
             raise NotImplementedError(
@@ -126,18 +126,18 @@ class FlashRWAttention(torch.nn.Module):
         )
 
     def forward(
-            self,
-            hidden_states,
-            cos,
-            sin,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
-            max_s,
-            layer_past,
-            past_present_indices,
-            prefill,
+        self,
+        hidden_states,
+        cos,
+        sin,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        layer_past,
+        past_present_indices,
+        prefill,
     ):
 
         qkv = self.query_key_value(hidden_states)
@@ -165,7 +165,7 @@ class FlashRWAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -194,7 +194,7 @@ class FlashRWAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -264,15 +264,18 @@ class FlashRWLargeAttention(torch.nn.Module):
         )
 
     def forward(
-            self,
-            hidden_states,
-            cos,
-            sin,
-            cu_seqlens,
-            max_s,
-            layer_past,
-            past_present_indices,
-            cu_seqlens_q,
+        self,
+        hidden_states,
+        cos,
+        sin,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        layer_past,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -287,12 +290,12 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
         # Prefill
-        if past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 kv.unsqueeze(2)
@@ -303,13 +306,15 @@ class FlashRWLargeAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -334,13 +339,15 @@ class FlashRWLargeAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -419,19 +426,19 @@ class FlashRWLayer(nn.Module):
         self.process_group = weights.process_group
 
     def forward(
-            self,
-            hidden_states,
-            residual,
-            cos,
-            sin,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
-            max_s,
-            layer_past,
-            past_present_indices,
-            prefill,
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        layer_past,
+        past_present_indices,
+        prefill,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -509,16 +516,19 @@ class FlashRWLargeLayer(nn.Module):
         self.process_group = weights.process_group
 
    def forward(
-            self,
-            hidden_states,
-            residual,
-            cos,
-            sin,
-            cu_seqlens,
-            max_s,
-            layer_past,
-            past_present_indices,
-            cu_seqlens_q,
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        layer_past,
+        past_present_indices,
+        prefill,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
         ln_mlp, _ = self.ln_mlp(residual)
@@ -528,11 +538,14 @@ class FlashRWLargeLayer(nn.Module):
             ln_attn,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
             past_present_indices,
-            cu_seqlens_q,
+            prefill,
         )
 
         # MLP.
@@ -570,7 +583,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 self.h[0].self_attention.head_size,
             )
         elif config.model_type == "RefinedWeb":
-            raise NotImplementedError
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -596,17 +608,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.head_size = self.h[0].self_attention.head_size
 
     def forward(
-            self,
-            input_ids,
-            position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
-            max_s,
-            past_present_indices,
-            past_key_values=None,
-            pre_allocate_past_size: Optional[int] = None,
+        self,
+        input_ids,
+        position_ids,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        past_present_indices,
+        past_key_values=None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.word_embeddings(input_ids)
@@ -617,7 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
             prefill = True
 
             # Create past tensor
-            past_key_values = hidden_states.new_zeros(
+            past_key_values = hidden_states.new_empty(
                 (
                     pre_allocate_past_size,
                     len(self.h),
@@ -646,7 +658,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                past_key_values[:, i],
+                torch.select(past_key_values, dim=1, index=i),
                 past_present_indices,
                 prefill,
             )
@@ -667,18 +679,18 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         )
 
     def forward(
-            self,
-            input_ids,
-            position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
-            max_s,
-            past_present_indices,
-            past_key_values: Optional[torch.Tensor] = None,
-            pre_allocate_past_size: Optional[int] = None,
-            lm_head_indices: Optional[torch.Tensor] = None,
+        self,
+        input_ids,
+        position_ids,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        past_present_indices,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 70e02e76..2ccdf045 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
 from typing import Optional
 
 # Flash attention imports
-import flash_attn_cuda_modif
+import flash_attn_cuda
 
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -179,7 +179,7 @@ class FlashMQAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -208,7 +208,7 @@ class FlashMQAttention(torch.nn.Module):
             # output
             attn_output = torch.empty_like(query)
             # flash attention
-            flash_attn_cuda_modif.fwd(
+            flash_attn_cuda.fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -373,13 +373,7 @@ class FlashSantacoderModel(nn.Module):
 
             # Create past tensor
             past_key_values = hidden_states.new_zeros(
-                (
-                    pre_allocate_past_size,
-                    len(self.h),
-                    2,
-                    1,
-                    self.head_size
-                )
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
             )
         # Decode
         else:
@@ -415,17 +409,17 @@ class FlashSantacoderForCausalLM(nn.Module):
 
     def forward(
         self,
-            input_ids,
-            position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
-            max_s,
-            past_present_indices,
-            past_key_values: Optional[torch.Tensor] = None,
-            pre_allocate_past_size: Optional[int] = None,
-            lm_head_indices: Optional[torch.Tensor] = None,
+        input_ids,
+        position_ids,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        past_present_indices,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bdfa5051..ecea998e 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -82,11 +82,11 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            dtype: torch.dtype,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "FlashCausalLMBatch":
         batch_inputs = []
         max_truncation = 0
@@ -184,7 +184,11 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1
 
-            request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
+            request_past_present_indices = torch.arange(
+                cumulative_max_length,
+                cumulative_max_length + input_length,
+                dtype=torch.int64,
+            )
             past_present_indices.append(request_past_present_indices)
 
             # Update
@@ -217,8 +221,12 @@ class FlashCausalLMBatch(Batch):
 
             past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
 
-            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
-            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
+            start_seq_prefill = torch.tensor(
+                start_seq_prefill, device=device, dtype=torch.int32
+            )
+            end_seq_prefill = torch.tensor(
+                end_seq_prefill, device=device, dtype=torch.int32
+            )
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
@@ -230,7 +238,9 @@ class FlashCausalLMBatch(Batch):
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
+        past_present_indices = torch.tensor(
+            past_present_indices, device=device, dtype=torch.int64
+        )
 
         if all_prefill_logprobs:
             prefill_head_indices = None
@@ -294,7 +304,9 @@ class FlashCausalLMBatch(Batch):
         indices = []
 
         # past indices to keep
-        past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)
+        past_indices = torch.zeros(
+            self.past_key_values.shape[0], dtype=torch.bool, device=device
+        )
 
         # Create on CPU to only move to GPU once instead of at every copy
         start_seq = torch.empty(len(request_ids), dtype=torch.int32)
@@ -332,14 +344,18 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
-            remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            remaining_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
 
             # Copy to tensor (CPU)
             start_seq[i] = cumulative_max_length
             end_seq[i] = cumulative_max_length + request_input_length
 
             # Set slice
-            past_indices[self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1] = True
+            past_indices[
+                self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
+            ] = True
 
             cumulative_max_length += request_input_length + remaining_tokens - 1
 
@@ -480,7 +496,7 @@ class FlashCausalLMBatch(Batch):
             end_index = cumulative_batch_size + len(batch)
 
             all_input_ids_tensor[
-            start_index:end_index, : batch.all_input_ids_tensor.shape[1]
+                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
 
             cumulative_batch_size += len(batch)
@@ -523,12 +539,12 @@ class FlashCausalLMBatch(Batch):
 
 class FlashCausalLM(Model):
     def __init__(
-            self,
-            model_cls: Type[PreTrainedModel],
-            model_id: str,
-            revision: Optional[str] = None,
-            quantize: Optional[str] = None,
-            trust_remote_code: bool = False,
+        self,
+        model_cls: Type[PreTrainedModel],
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -569,18 +585,18 @@ class FlashCausalLM(Model):
         )
 
     def forward(
-            self,
-            input_ids: torch.Tensor,
-            position_ids: torch.Tensor,
-            start_seq: torch.Tensor,
-            end_seq: torch.Tensor,
-            start_seq_q: Optional[torch.Tensor],
-            end_seq_q: Optional[torch.Tensor],
-            max_s: int,
-            past_present_indices: torch.Tensor,
-            past_key_values: Optional = None,
-            pre_allocate_past_size: Optional[int] = None,
-            lm_head_indices: Optional[torch.Tensor] = None,
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq: torch.Tensor,
+        end_seq: torch.Tensor,
+        start_seq_q: Optional[torch.Tensor],
+        end_seq_q: Optional[torch.Tensor],
+        max_s: int,
+        past_present_indices: torch.Tensor,
+        past_key_values: Optional = None,
+        pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -599,7 +615,7 @@ class FlashCausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-            self, batch: FlashCausalLMBatch
+        self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -647,7 +663,9 @@ class FlashCausalLM(Model):
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
             # Create batch.start_seq_q and batch.end_seq_q for decode
-            batch.start_seq_q = torch.arange(0, len(batch), device=self.device, dtype=torch.int32)
+            batch.start_seq_q = torch.arange(
+                0, len(batch), device=self.device, dtype=torch.int32
+            )
             batch.end_seq_q = batch.start_seq_q + 1
             next_position_ids = batch.position_ids.new_empty(len(batch))
             # We do not need start_seq_prefill and end_seq_prefill anymore
@@ -783,7 +801,7 @@ class FlashCausalLM(Model):
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:]
+                    all_input_ids[-stopping_criteria.current_tokens :]
                 )
                 generated_text = GeneratedText(
                     output_text,
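
Taken as a whole, the patch replaces the single cumulative-lengths tensors (`cu_seqlens` / `cu_seqlens_q`) with explicit `start_seq` / `end_seq` and `start_seq_q` / `end_seq_q` tensors, threads a `past_present_indices` tensor and a `prefill` flag through every flash model, and pre-allocates the key/value cache to `pre_allocate_past_size` slots so each request owns a fixed slice it can write into during decode (`layer_past[past_present_indices] = kv`). The sketch below is not part of the patch; it is a minimal illustration, assuming a batch of two prompts of lengths 3 and 5 with up to 4 new tokens each, of how the cache-facing tensors built in `FlashCausalLMBatch` relate to one another. The helper name `build_seq_tensors` is hypothetical.

```python
import torch


def build_seq_tensors(input_lengths, max_new_tokens):
    # Hypothetical helper: mirrors the bookkeeping in FlashCausalLMBatch
    # (from_pb / filter) that lays requests out in the pre-allocated past tensor.
    start_seq, end_seq = [], []
    past_present_indices = []
    cumulative_max_length = 0

    for input_length, new_tokens in zip(input_lengths, max_new_tokens):
        # This request's slice of the cache starts here.
        start_seq.append(cumulative_max_length)
        end_seq.append(cumulative_max_length + input_length)
        # Cache slots written by the prompt tokens
        # (the models do `layer_past[past_present_indices] = kv`).
        past_present_indices.append(
            torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
        )
        # Reserve room for the prompt plus the tokens it may still generate,
        # like `cumulative_max_length += request_input_length + remaining_tokens - 1`.
        cumulative_max_length += input_length + new_tokens - 1

    return (
        torch.tensor(start_seq, dtype=torch.int32),
        torch.tensor(end_seq, dtype=torch.int32),
        torch.cat(past_present_indices),
        cumulative_max_length,  # becomes pre_allocate_past_size
    )


start_seq, end_seq, past_present_indices, pre_allocate_past_size = build_seq_tensors(
    [3, 5], [4, 4]
)
print(start_seq.tolist(), end_seq.tolist(), pre_allocate_past_size)
# [0, 6] [3, 11] 14 -- the old cu_seqlens representation would have been [0, 3, 8],
# packed back to back with no per-request headroom for decode.
```

Keeping explicit start/end bounds instead of one packed cumulative vector is what lets the batch reserve growth room per request up front, so each decode step only scatters the new key/value pair into the pre-sized past tensor instead of reallocating it; on the query side, decode collapses to one token per request, which is why `generate_token` builds `start_seq_q = torch.arange(0, len(batch), ...)` and `end_seq_q = start_seq_q + 1`.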