diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 5cc4e782..24499d09 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -26,6 +26,7 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -255,8 +256,9 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, ): qkv = self.query_key_value(hidden_states) @@ -308,8 +310,8 @@ class FlashCohereAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - None, - input_lengths, + cu_seqlen_q, + cu_seqlen_k, max_s, ) @@ -383,8 +385,9 @@ class FlashCohereLayer(nn.Module): cu_seqlen_prefill, kv_cache, block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -397,8 +400,9 @@ class FlashCohereLayer(nn.Module): cu_seqlen_prefill, kv_cache, block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, ) @@ -461,6 +465,24 @@ class FlashCohereModel(torch.nn.Module): ) residual = None + if cu_seqlen_prefill is None and FLASH_DECODING: + cu_seqlen_q = torch.arange( + input_lengths.shape[0] + 1, + device=input_ids.device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.cat( + [ + torch.zeros( + (1,), device=input_lengths.device, dtype=input_lengths.dtype + ), + input_lengths.cumsum(dim=-1), + ] + ).to(dtype=torch.int32) + else: + cu_seqlen_q = None + cu_seqlen_k = input_lengths + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -470,8 +492,9 @@ class FlashCohereModel(torch.nn.Module): cu_seqlen_prefill, kv_cache[i], block_tables, + cu_seqlen_q, + cu_seqlen_k, slots, - input_lengths, max_s, )