Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing Cohere flash decoding.
commit 7890cd66f7
parent a6f1603525
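In short: the Cohere flash-attention model previously passed None and the raw per-sequence input_lengths to the decode-path attention call, which does not work under flash decoding. This commit imports the FLASH_DECODING flag, computes cumulative sequence-length tensors cu_seqlen_q / cu_seqlen_k in FlashCohereModel.forward when flash decoding is active, and threads them through FlashCohereLayer and FlashCohereAttention in place of input_lengths.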
@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,
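FLASH_DECODING is a module-level feature flag in text_generation_server.models.globals. Its exact definition is not shown in this diff; a minimal sketch of the usual pattern, assuming it is gated by an environment variable of the same name:

import os

# Sketch only: an opt-in flag read once at import time. The real definition in
# models/globals.py may differ (e.g. in accepted values or logging).
FLASH_DECODING = os.getenv("FLASH_DECODING", "").lower() in ("1", "true")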
@@ -255,8 +256,9 @@ class FlashCohereAttention(torch.nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
+        cu_seqlen_q,
+        cu_seqlen_k,
         slots,
-        input_lengths,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
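The remaining hunks repeat this signature change: cu_seqlen_q and cu_seqlen_k replace input_lengths in FlashCohereAttention.forward, FlashCohereLayer.forward, and the per-layer call in FlashCohereModel.forward. A minimal sketch of the threading pattern, with hypothetical stand-in functions (not TGI's classes):

# The model computes the two cumulative-length tensors once, and every layer
# passes them straight down to its attention module.
def attention_forward(hidden_states, cu_seqlen_q, cu_seqlen_k):
    # the attention kernel consumes the cumulative offsets here
    return hidden_states

def layer_forward(hidden_states, cu_seqlen_q, cu_seqlen_k):
    return attention_forward(hidden_states, cu_seqlen_q, cu_seqlen_k)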
@@ -308,8 +310,8 @@ class FlashCohereAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
-                input_lengths,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 max_s,
             )
 
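This hunk is the actual fix: the decode-path attention call previously received None and the raw per-sequence lengths, but a varlen-style kernel wants cumulative offsets into the packed query/key tensors. A hypothetical illustration (not TGI's kernel) of how such offsets are consumed:

import torch

def varlen_slices(cu_seqlen_k):
    # Yields the (start, end) range of sequence i inside the packed key/value
    # tensor; raw lengths like [3, 5, 4] cannot be used this way directly.
    for i in range(cu_seqlen_k.shape[0] - 1):
        yield int(cu_seqlen_k[i]), int(cu_seqlen_k[i + 1])

cu_seqlen_k = torch.tensor([0, 3, 8, 12], dtype=torch.int32)
print(list(varlen_slices(cu_seqlen_k)))  # [(0, 3), (3, 8), (8, 12)]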
@@ -383,8 +385,9 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
+        cu_seqlen_q,
+        cu_seqlen_k,
         slots,
-        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -397,8 +400,9 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
+            cu_seqlen_q,
+            cu_seqlen_k,
             slots,
-            input_lengths,
             max_s,
         )
 
@@ -461,6 +465,24 @@ class FlashCohereModel(torch.nn.Module):
         )
 
         residual = None
+        if cu_seqlen_prefill is None and FLASH_DECODING:
+            cu_seqlen_q = torch.arange(
+                input_lengths.shape[0] + 1,
+                device=input_ids.device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.cat(
+                [
+                    torch.zeros(
+                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
+                    ),
+                    input_lengths.cumsum(dim=-1),
+                ]
+            ).to(dtype=torch.int32)
+        else:
+            cu_seqlen_q = None
+            cu_seqlen_k = input_lengths
+
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
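For concreteness, a worked example of what the new branch produces during decode: each sequence contributes exactly one query token, so cu_seqlen_q is simply 0..batch_size, while cu_seqlen_k is the exclusive prefix sum of the per-sequence key lengths. A CPU-only sketch of the same tensor math:

import torch

input_lengths = torch.tensor([3, 5, 4])  # three sequences in a decode batch

# One new query token per sequence, so the query offsets are 0..batch_size.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# Key offsets are the exclusive prefix sum of the per-sequence lengths.
cu_seqlen_k = torch.cat(
    [torch.zeros((1,), dtype=input_lengths.dtype), input_lengths.cumsum(dim=-1)]
).to(dtype=torch.int32)

print(cu_seqlen_q)  # tensor([0, 1, 2, 3], dtype=torch.int32)
print(cu_seqlen_k)  # tensor([0, 3, 8, 12], dtype=torch.int32)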
@@ -470,8 +492,9 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
+                cu_seqlen_q,
+                cu_seqlen_k,
                 slots,
-                input_lengths,
                 max_s,
             )
 