Fixing cohere flash decoding.

Nicolas Patry 2024-05-29 16:04:36 +00:00
parent a6f1603525
commit 7890cd66f7

@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
    TensorParallelRowLinear,
@@ -255,8 +256,9 @@ class FlashCohereAttention(torch.nn.Module):
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        cu_seqlen_q,
        cu_seqlen_k,
        slots,
        input_lengths,
        max_s,
    ):
        qkv = self.query_key_value(hidden_states)
@@ -308,8 +310,8 @@ class FlashCohereAttention(torch.nn.Module):
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                None,
                input_lengths,
                cu_seqlen_q,
                cu_seqlen_k,
                max_s,
            )
@@ -383,8 +385,9 @@ class FlashCohereLayer(nn.Module):
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        cu_seqlen_q,
        cu_seqlen_k,
        slots,
        input_lengths,
        max_s,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -397,8 +400,9 @@ class FlashCohereLayer(nn.Module):
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            cu_seqlen_q,
            cu_seqlen_k,
            slots,
            input_lengths,
            max_s,
        )
@@ -461,6 +465,24 @@ class FlashCohereModel(torch.nn.Module):
        )
        residual = None
        if cu_seqlen_prefill is None and FLASH_DECODING:
            cu_seqlen_q = torch.arange(
                input_lengths.shape[0] + 1,
                device=input_ids.device,
                dtype=torch.int32,
            )
            cu_seqlen_k = torch.cat(
                [
                    torch.zeros(
                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
                    ),
                    input_lengths.cumsum(dim=-1),
                ]
            ).to(dtype=torch.int32)
        else:
            cu_seqlen_q = None
            cu_seqlen_k = input_lengths
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
@@ -470,8 +492,9 @@ class FlashCohereModel(torch.nn.Module):
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                cu_seqlen_q,
                cu_seqlen_k,
                slots,
                input_lengths,
                max_s,
            )
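A minimal standalone sketch of the bookkeeping the new FLASH_DECODING branch builds, assuming a toy decode batch of three sequences; the batch values here are illustrative and not from this commit, and devices are left on CPU, whereas the model code places the tensors on the input's device.

import torch

# Toy decode step: three sequences, each with this many tokens already
# present in the KV cache (keys per sequence).
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# One new query token per sequence, so the cumulative query offsets are
# simply 0..batch_size.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# Cumulative key lengths with a leading zero, so the keys of sequence i
# live in the range [cu_seqlen_k[i], cu_seqlen_k[i + 1]).
cu_seqlen_k = torch.cat(
    [
        torch.zeros((1,), dtype=input_lengths.dtype),
        input_lengths.cumsum(dim=-1),
    ]
).to(dtype=torch.int32)

print(cu_seqlen_q)  # -> [0, 1, 2, 3]
print(cu_seqlen_k)  # -> [0, 3, 8, 10]

These are the two tensors that the decode path above hands to paged_attention.attention as cu_seqlen_q and cu_seqlen_k; on the prefill path they stay as None and input_lengths, as in the else branch of the hunk in FlashCohereModel.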