Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-05-09 11:22:09 +00:00.
Update all models.
This commit is contained in:
parent 65b94a69bd
commit bb9769ed42
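Every hunk below applies the same mechanical change: Seqlen is imported from text_generation_server.layers.attention, a seqlen value is threaded through the attention, layer, model and *ForCausalLM forward methods alongside the existing input_lengths/max_s arguments, and the prefill attention(...) call also receives the KV-cache tensors and block_tables. A minimal sketch of that threading pattern follows; the Seqlen dataclass and the attention stub are placeholders rather than the real TGI helpers, and the exact kernel signature is an assumption, since this page does not preserve which diff lines were added and which removed.

    from dataclasses import dataclass
    from typing import Tuple

    import torch


    @dataclass
    class Seqlen:
        # Placeholder for text_generation_server.layers.attention.Seqlen:
        # per-request sequence-length metadata consumed by the kernels.
        input_lengths: torch.Tensor


    def attention(query, key, value, key_cache, value_cache, seqlen, block_tables, softmax_scale):
        # Stub standing in for the flash-attention prefill wrapper.
        return torch.zeros_like(query)


    def prefill(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        block_tables: torch.Tensor,
        seqlen: Seqlen,
        softmax_scale: float,
    ) -> torch.Tensor:
        # The call shape the models converge on: cache tensors, seqlen metadata
        # and block tables all go straight into the kernel wrapper.
        return attention(
            query,
            key,
            value,
            kv_cache[0],
            kv_cache[1],
            seqlen,
            block_tables,
            softmax_scale,
        )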
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (

@@ -264,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -296,12 +297,10 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -313,7 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -388,7 +387,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -402,7 +401,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -454,7 +453,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -477,7 +476,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -518,7 +517,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -531,7 +530,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
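Throughout these models kv_cache is a per-layer list of (key_cache, value_cache) tensor pairs, and block_tables and slots map each sequence onto blocks of that cache (cf. reshape_and_cache in the imports). The toy illustration below shows that relationship only; the block size, shapes and gather logic are assumptions made for the sketch, not the layout used by text-generation-inference.

    import torch

    # Toy paged KV cache; sizes are illustrative assumptions.
    BLOCK_SIZE = 16
    NUM_BLOCKS = 8
    NUM_HEADS = 2
    HEAD_DIM = 4

    # One (key_cache, value_cache) pair per layer, as in
    # kv_cache: List[Tuple[torch.Tensor, torch.Tensor]].
    kv_cache = [
        (
            torch.zeros(NUM_BLOCKS, BLOCK_SIZE, NUM_HEADS, HEAD_DIM),
            torch.zeros(NUM_BLOCKS, BLOCK_SIZE, NUM_HEADS, HEAD_DIM),
        )
    ]

    # block_tables maps each sequence to the cache blocks it owns; slots are the
    # flat cache positions written during prefill.
    block_tables = torch.tensor([[3, 5]])          # sequence 0 uses blocks 3 and 5
    slots = torch.tensor([3 * BLOCK_SIZE + 0])     # first token of sequence 0

    def gather_keys(layer: int, seq: int, length: int) -> torch.Tensor:
        # Read `length` cached key vectors for one sequence by walking its blocks.
        key_cache, _ = kv_cache[layer]
        blocks = block_tables[seq]
        flat = key_cache[blocks].reshape(-1, NUM_HEADS, HEAD_DIM)
        return flat[:length]

    print(gather_keys(0, 0, 5).shape)  # torch.Size([5, 2, 4])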
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,

@@ -309,7 +310,7 @@ class DbrxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -335,12 +336,10 @@ class DbrxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -352,7 +351,7 @@ class DbrxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -389,7 +388,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)

@@ -403,7 +402,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -622,7 +621,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Self Attention

@@ -635,7 +634,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -679,7 +678,7 @@ class DbrxModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -701,7 +700,7 @@ class DbrxModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -734,7 +733,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -747,7 +746,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -29,8 +29,8 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers.attention.common import Seqlen
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM

@@ -298,7 +298,7 @@ class DeepseekV2Attention(torch.nn.Module):
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:

@@ -363,12 +363,10 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -380,7 +378,7 @@ class DeepseekV2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -666,7 +664,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)

@@ -680,7 +678,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -729,7 +727,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -751,7 +749,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -781,7 +779,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -794,7 +792,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,

@@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
softcap=self.softcap,
)

@@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds

@@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
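The Gemma2 hunks above are the ones that carry extra keyword arguments: causal=self.causal and window_size_left=self.window_size on the flash-attention call, and softcap=self.softcap on the paged-attention call. Below is a small numerical sketch of what sliding-window masking and logit soft-capping mean; the tanh form and the constants are assumptions made for illustration, not values read from this diff.

    import torch

    def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
        # Soft-capping squashes attention logits into (-softcap, softcap) with a
        # tanh instead of letting them grow unbounded.
        return softcap * torch.tanh(scores / softcap)

    def sliding_window_mask(q_len: int, window_size_left: int) -> torch.Tensor:
        # Causal mask that additionally forbids attending further than
        # `window_size_left` tokens into the past.
        idx = torch.arange(q_len)
        causal = idx[None, :] <= idx[:, None]
        window = idx[:, None] - idx[None, :] <= window_size_left
        return causal & window

    scores = torch.randn(4, 4) * 100.0
    capped = softcap_scores(scores, softcap=50.0)
    mask = sliding_window_mask(4, window_size_left=2)
    masked = capped.masked_fill(~mask, float("-inf"))
    print(torch.softmax(masked, dim=-1))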
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
)

@@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds

@@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(

@@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
residual = hidden_states

@@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],

@@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,

@@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(

@@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
cu_seqlen_prefill,
max_s,
kv_cache[0],
kv_cache[1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)

@@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:

@@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,

@@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices=prefill_cache_indices,
)
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,

@@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)

@@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,

@@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,

@@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],

@@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,

@@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)

inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

@@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
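The FlashMistralForCausalLM hunk above keeps the comment explaining the clamp: decode-side paged attention wants lengths clamped to the sliding window, while the flash-attention prefill kernel needs the true values (the same pattern appears below for Mixtral and Starcoder2). A toy sketch of that split follows, with a hypothetical max_past of 4; clamping a plain tensor here stands in for the clamp applied to input_lengths/seqlen in the diff.

    import torch

    max_past = 4  # hypothetical sliding-window size

    # True per-request lengths, needed as-is by the flash-attention prefill kernel.
    input_lengths = torch.tensor([2, 7, 11])

    # Decode-side paged attention only ever looks at the last `max_past` tokens,
    # so the lengths it receives are clamped to the window.
    clamped = input_lengths.clamp(max=max_past)

    print(input_lengths.tolist())  # [2, 7, 11]
    print(clamped.tolist())        # [2, 4, 4]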
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,

@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):

@@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)

@@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):

@@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)

@@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],

@@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)

@@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)

hidden_states = self.model(
input_ids,

@@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
if self.use_parallel_residual:

@@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)

@@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Compute query, key, value and split

@@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)

@@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -343,7 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -429,7 +428,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
if self.parallel_attn:

@@ -443,7 +442,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -465,7 +464,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -552,7 +551,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Layer norm.

@@ -567,7 +566,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -628,7 +627,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)

@@ -650,7 +649,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -680,7 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -693,7 +692,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.c_attn(hidden_states)

@@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode

@@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -373,7 +372,7 @@ class Block(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)

@@ -383,7 +382,7 @@ class Block(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)

@@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)

@@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -209,7 +210,7 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):

@@ -240,12 +241,10 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)

@@ -258,7 +257,7 @@ class Starcoder2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)

@@ -381,7 +380,7 @@ class Starcoder2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):

@@ -396,7 +395,7 @@ class Starcoder2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)

@@ -449,7 +448,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],

@@ -473,7 +472,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)

@@ -521,7 +520,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -534,7 +533,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)

hidden_states = self.model(
input_ids,

@@ -543,7 +542,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,