Remove block_tables and prefill_cache_indices from the model forward paths, since they lead to dynamic shapes

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-27 22:51:21 -07:00
parent 7900be5ac3
commit 1508ee8de1
27 changed files with 88 additions and 530 deletions
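Rationale, illustrated: in every flash model below, the prefill key/value tensors were sliced with prefill_cache_indices before being written to the KV cache, so the stored tensor's length depended on the request mix; on HPU that data-dependent size means dynamic shapes and graph recompilation. The commit drops the slicing (and the block_tables argument threaded alongside it) and stores the full padded tensors instead. A minimal sketch of the before/after pattern, using toy tensors rather than the repository's KVCache:

```python
import torch

# Toy shapes: [padded_tokens, kv_heads, head_size]; the real code gets these
# from the attention modules shown in the hunks below.
key = torch.randn(16, 8, 64)
value = torch.randn(16, 8, 64)
slots = torch.arange(16)                                   # one cache slot per padded token
prefill_cache_indices = torch.tensor([0, 1, 2, 5, 6, 7])   # positions holding real tokens

# Before: fancy indexing makes the cached tensor's length depend on the
# batch contents, which is a dynamic shape from the HPU graph's point of view.
key_to_cache = key[prefill_cache_indices]      # [6, 8, 64] here, varies per batch
value_to_cache = value[prefill_cache_indices]

# After: cache every padded position; shapes stay [16, 8, 64] for every
# batch of this bucket size, so the compiled graph can be reused.
key_to_cache, value_to_cache = key, value
```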

View File

@ -26,7 +26,6 @@ def attention(
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
@ -61,7 +60,6 @@ def paged_attention(
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
*,
kv_scales: KVScales,
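For reference, the trimmed helper interfaces that the rest of this diff calls into, reconstructed from the hunks above and the call sites below; the stub types and the exact argument order are illustrative, not a verbatim copy of the server's attention module:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Seqlen:
    """Stand-in for the server's Seqlen metadata object."""
    input_lengths: torch.Tensor


def attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache,                       # KVCache in the real code
    kv_scales,                      # KVScales in the real code
    seqlen: Seqlen,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Prefill attention: block_tables is no longer part of the signature."""


def paged_attention(
    query: torch.Tensor,
    kv_cache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales,
    softcap: Optional[float] = None,
    hpu_attention_meta=None,        # HPUPagedAttentionMetadata in the real code
):
    """Decode attention: block layout presumably travels via hpu_attention_meta."""
```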

View File

@ -219,10 +219,8 @@ class FlashCohereAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -247,16 +245,9 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key_to_cache,
value=value_to_cache,
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
@ -271,7 +262,6 @@ class FlashCohereAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -281,7 +271,6 @@ class FlashCohereAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -356,10 +345,8 @@ class FlashCohereLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -371,10 +358,8 @@ class FlashCohereLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -424,10 +409,8 @@ class FlashCohereModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: torch.Tensor,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -446,10 +429,8 @@ class FlashCohereModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -488,10 +469,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -501,10 +480,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -308,10 +308,8 @@ class DbrxAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -329,14 +327,10 @@ class DbrxAttention(torch.nn.Module):
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -351,7 +345,6 @@ class DbrxAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -361,7 +354,6 @@ class DbrxAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -398,10 +390,8 @@ class DbrxNormAttentionNorm(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)
@ -413,10 +403,8 @@ class DbrxNormAttentionNorm(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -630,10 +618,8 @@ class DbrxLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
# Self Attention
@ -644,10 +630,8 @@ class DbrxLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -689,10 +673,8 @@ class DbrxModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -710,10 +692,8 @@ class DbrxModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -744,10 +724,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -757,10 +735,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -256,10 +256,8 @@ class DeepseekV2Attention(torch.nn.Module):
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
@ -316,15 +314,10 @@ class DeepseekV2Attention(torch.nn.Module):
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key_to_cache,
value=value_to_cache,
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
@ -339,7 +332,6 @@ class DeepseekV2Attention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -349,7 +341,6 @@ class DeepseekV2Attention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -512,10 +503,8 @@ class DeepseekV2Layer(nn.Module):
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -527,10 +516,8 @@ class DeepseekV2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -577,10 +564,8 @@ class DeepseekV2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -598,10 +583,8 @@ class DeepseekV2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -629,10 +612,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -642,10 +623,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -256,10 +256,8 @@ class DeepseekV3Attention(torch.nn.Module):
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
@ -317,15 +315,9 @@ class DeepseekV3Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key_to_cache,
value=value_to_cache,
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
@ -340,7 +332,6 @@ class DeepseekV3Attention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -350,7 +341,6 @@ class DeepseekV3Attention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -522,10 +512,8 @@ class DeepseekV3Layer(nn.Module):
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -537,10 +525,8 @@ class DeepseekV3Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -587,10 +573,8 @@ class DeepseekV3Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -608,10 +592,8 @@ class DeepseekV3Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -639,10 +621,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -652,10 +632,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -235,11 +235,9 @@ class FlashGemma2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
@ -254,14 +252,10 @@ class FlashGemma2Attention(torch.nn.Module):
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -276,7 +270,6 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.window_size,
softcap=self.softcap,
@ -288,7 +281,6 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
softcap=self.softcap,
kv_scales=self.kv_scales,
@ -402,11 +394,9 @@ class FlashGemma2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -418,11 +408,9 @@ class FlashGemma2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
@ -472,11 +460,9 @@ class FlashGemma2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -494,11 +480,9 @@ class FlashGemma2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
@ -543,10 +527,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -557,11 +539,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -207,10 +207,8 @@ class FlashGemmaAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -226,14 +224,9 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -248,7 +241,6 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
causal=self.causal,
)
@ -259,7 +251,6 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -331,10 +322,8 @@ class FlashGemmaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -346,10 +335,8 @@ class FlashGemmaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -395,11 +382,9 @@ class FlashGemmaModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -417,10 +402,8 @@ class FlashGemmaModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -463,10 +446,8 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -477,11 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -209,10 +209,8 @@ class FlashGPT2Attention(torch.nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
query, key, value = self.query_key_value(hidden_states).split(
@ -222,16 +220,9 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
if prefill_cache_indices is not None:
key_to_cache = key[prefill_cache_indices]
value_to_cache = value[prefill_cache_indices]
else:
key_to_cache = key
value_to_cache = value
kv_cache.store(
key=key_to_cache,
value=value_to_cache,
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
@ -246,7 +237,6 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -256,7 +246,6 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -325,10 +314,8 @@ class FlashGPT2Layer(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
residual = hidden_states
@ -339,10 +326,8 @@ class FlashGPT2Layer(nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -393,10 +378,8 @@ class FlashGPT2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -408,10 +391,8 @@ class FlashGPT2Model(torch.nn.Module):
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -446,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -462,10 +441,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices=prefill_cache_indices,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -201,11 +201,9 @@ class FlashLlamaAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache: KVCache,
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
qkv = self.query_key_value(hidden_states, adapter_data)
@ -221,14 +219,9 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -243,7 +236,6 @@ class FlashLlamaAttention(torch.nn.Module):
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -253,7 +245,6 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -441,12 +432,10 @@ class FlashLlamaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
cross_attention_states,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -458,11 +447,9 @@ class FlashLlamaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta=hpu_attention_meta,
)
if self.residual_multiplier is not None:
@ -554,10 +541,8 @@ class FlashLlamaModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
@ -577,12 +562,10 @@ class FlashLlamaModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
adapter_data,
cross_attention_states,
prefill_cache_indices,
hpu_attention_meta=hpu_attention_meta,
)
@ -643,30 +626,21 @@ class FlashLlamaForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
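The guard removed just above ("Slots also need to be sliced as it has the same size as the whole kv tensor") was the same dynamic-shape issue at the CausalLM level. A toy illustration, with invented sizes, of why slots can now keep its padded length; how the pad positions are filtered later is presumably handled by the boolean prefill_cache_indices mask built in the flash_causal_lm hunks at the end of this diff:

```python
import torch

padded_tokens = 16                         # every request padded to the bucket length
slots = torch.arange(padded_tokens)        # one KV-cache slot per padded position

# Before: slots was sliced to match the sliced key/value tensors, so its
# length changed with the actual token count in the batch.
prefill_cache_indices = torch.tensor([3, 4, 5, 6, 7, 11, 12, 13, 14, 15])
sliced_slots = slots[prefill_cache_indices]     # length 10, data dependent

# After: slots keeps the padded, bucket-sized length; pad positions simply
# write into their own slots, and the shape seen by the HPU graph is constant.
assert slots.numel() == padded_tokens
```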

View File

@ -169,11 +169,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
@ -276,11 +274,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:

View File

@ -178,10 +178,8 @@ class MistralAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
@ -198,14 +196,9 @@ class MistralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -220,7 +213,6 @@ class MistralAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
@ -231,7 +223,6 @@ class MistralAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -335,10 +326,8 @@ class MistralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
@ -351,10 +340,8 @@ class MistralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
@ -403,10 +390,8 @@ class MistralModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
adapter_data: Optional[torch.Tensor] = None,
):
@ -424,10 +409,8 @@ class MistralModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
@ -475,30 +458,20 @@ class FlashMistralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
adapter_data,
)

View File

@ -235,10 +235,8 @@ class MixtralAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -254,14 +252,9 @@ class MixtralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -276,7 +269,6 @@ class MixtralAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
@ -287,7 +279,6 @@ class MixtralAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -384,10 +375,8 @@ class MixtralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -399,10 +388,8 @@ class MixtralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -454,10 +441,8 @@ class MixtralModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -475,10 +460,8 @@ class MixtralModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -510,29 +493,20 @@ class FlashMixtralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -801,12 +801,10 @@ class FlashLlamaCrossLayer(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
cross_attention_states, # [ IB, ...]
prefill_cache_indices,
hpu_attention_meta,
) -> Tuple[torch.Tensor, torch.Tensor]:
if cross_attention_states is None:
@ -911,11 +909,9 @@ class FlashMllamaForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
# XXX: Putting these as optional so that the cuda warmup calls can go through.
@ -979,11 +975,9 @@ class FlashMllamaForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,

View File

@ -147,10 +147,8 @@ class FlashNeoxAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -166,14 +164,10 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(query_rot, key_rot, cos, sin)
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
if prefill_cache_indices is not None:
qkv_to_cache = qkv[prefill_cache_indices]
else:
qkv_to_cache = qkv
kv_cache.store(
key=qkv_to_cache[:, 1],
value=qkv_to_cache[:, 2],
key=qkv[:, 1],
value=qkv[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
@ -188,7 +182,6 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -198,7 +191,6 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -268,10 +260,8 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
if self.use_parallel_residual:
@ -283,10 +273,8 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -308,10 +296,8 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -363,10 +349,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
@ -384,10 +368,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -417,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -430,10 +410,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -69,11 +69,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
@ -106,11 +104,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)

View File

@ -160,10 +160,8 @@ class FlashPhiAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
# Compute query, key, value and split
@ -190,13 +188,9 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -210,7 +204,6 @@ class FlashPhiAttention(torch.nn.Module):
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -220,7 +213,6 @@ class FlashPhiAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -287,10 +279,8 @@ class FlashPhiLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -301,10 +291,8 @@ class FlashPhiLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -354,10 +342,8 @@ class FlashPhiModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -375,10 +361,8 @@ class FlashPhiModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -409,10 +393,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -422,10 +404,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -106,10 +106,8 @@ class Qwen2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -125,14 +123,9 @@ class Qwen2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -147,7 +140,6 @@ class Qwen2Attention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
@ -158,7 +150,6 @@ class Qwen2Attention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -230,10 +221,8 @@ class Qwen2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states)
@ -245,10 +234,8 @@ class Qwen2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
hidden_states = attn_output + residual
@ -296,10 +283,8 @@ class Qwen2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -317,10 +302,8 @@ class Qwen2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -364,21 +347,13 @@ class Qwen2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None and prefill_cache_indices.size(
0
) != slots.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@ -386,10 +361,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -182,10 +182,8 @@ class FlashRWAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -203,14 +201,9 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -225,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -235,7 +227,6 @@ class FlashRWAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -309,10 +300,8 @@ class FlashRWLargeAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states)
@ -329,14 +318,9 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, :, 0].contiguous(),
value=kv_to_cache[:, :, 1].contiguous(),
key=kv[:, :, 0].contiguous(),
value=kv[:, :, 1].contiguous(),
slots=slots,
kv_scales=self.kv_scales,
)
@ -351,7 +335,6 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -361,7 +344,6 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -447,10 +429,8 @@ class FlashRWLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
if self.parallel_attn:
@ -462,10 +442,8 @@ class FlashRWLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -485,10 +463,8 @@ class FlashRWLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -573,10 +549,8 @@ class FlashRWLargeLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
# Layer norm.
@ -589,10 +563,8 @@ class FlashRWLargeLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -651,10 +623,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
@ -672,10 +642,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -703,10 +671,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
@ -716,10 +682,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -265,10 +265,8 @@ class FlashMQAttention(torch.nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.c_attn(hidden_states)
@ -282,14 +280,9 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
if prefill_cache_indices is not None:
key_value_to_cache = key_value[prefill_cache_indices]
else:
key_value_to_cache = key_value
kv_cache.store(
key=key_value_to_cache[:, 0],
value=key_value_to_cache[:, 1],
key=key_value[:, 0],
value=key_value[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -304,7 +297,6 @@ class FlashMQAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -314,7 +306,6 @@ class FlashMQAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -379,10 +370,8 @@ class Block(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
@ -390,10 +379,8 @@ class Block(nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -445,10 +432,8 @@ class FlashSantacoderModel(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -463,10 +448,8 @@ class FlashSantacoderModel(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
@ -496,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -509,10 +490,8 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@ -235,10 +235,8 @@ class Starcoder2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
@ -255,14 +253,9 @@ class Starcoder2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -277,7 +270,6 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
@ -288,7 +280,6 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -448,10 +439,8 @@ class Starcoder2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
@ -464,10 +453,8 @@ class Starcoder2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
@ -518,10 +505,8 @@ class Starcoder2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
@ -540,10 +525,8 @@ class Starcoder2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
@ -589,29 +572,20 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)

View File

@ -740,11 +740,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
@ -843,11 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:

View File

@ -483,11 +483,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
@ -587,11 +585,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:

View File

@ -906,11 +906,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
@ -938,11 +936,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -480,11 +480,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
@ -511,11 +509,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -25,7 +25,7 @@ from typing import (
Dict,
Union,
)
import torch.nn.functional as F
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.models import Model
@ -116,7 +116,6 @@ def prepare_for_decode(
block_list = flatten(block_tables)
block_groups = flatten(block_groups)
block_usage = flatten(block_usage)
assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
if use_contiguous_pa:
@ -979,29 +978,27 @@ class FlashCausalLMBatch(Batch):
# padding to left to work with sliding window
# use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate
# the right logit position
input_ids_padded = None
input_ids_padded_length = None
input_ids_padded_length = []
# need extra pad to match warmup seq
extra_pad = 0
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded = []
input_ids_padded_length = []
input_ids = []
for input_id in self.input_ids:
padded = self.max_input_length - len(input_id)
input_id_padded = input_id
padded = self.max_input_length - len(input_id) + extra_pad
if padded > 0:
input_id_padded = [0] * padded + input_id_padded
input_ids_padded.append(input_id_padded)
input_id = [0] * padded + input_id
input_ids.append(input_id)
input_ids_padded_length.append(padded)
input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64)
input_ids_padded = torch.tensor(
input_ids_padded, dtype=torch.int64, device=device
)
if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
logger.error("should not be here, prefill self.input_ids is a tensor")
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
@ -1052,10 +1049,9 @@ class FlashCausalLMBatch(Batch):
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
if input_ids_padded is not None:
position_ids.append(
torch.ones(input_ids_padded_length[i], dtype=torch.int32)
)
request_position_ids = F.pad(
request_position_ids, (input_ids_padded_length[i], 0), value=1
)
position_ids.append(request_position_ids)
if not r.slots:
@ -1079,12 +1075,11 @@ class FlashCausalLMBatch(Batch):
cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill
if input_ids_padded is not None:
# hpu needs request_prefill_cache_indices to skip padding in the kv cache
sliding_window = get_sliding_windows()
if sliding_window is None:
sliding_window = input_length
cumulative_length += input_ids_padded_length[i]
# hpu needs request_prefill_cache_indices to skip padding in the kv cache
sliding_window = get_sliding_windows()
if sliding_window is None:
sliding_window = input_length
cumulative_length += input_ids_padded_length[i]
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
@ -1105,8 +1100,7 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index()
if ADAPTER_TO_INDEX:
@ -1171,23 +1165,20 @@ class FlashCausalLMBatch(Batch):
position_ids = torch.cat(position_ids)
if slot_indices:
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
if position_ids:
position_ids = position_ids[0]
if slot_indices:
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
self.prefill_cache_indices[prefill_cache_indices.to(device)] = True
if all_prefill_logprobs:
prefill_head_indices = None
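
The net effect of the last few hunks is that prefill_cache_indices changes from a variable-length index tensor (or None) into a boolean mask with the same fixed shape as input_ids, which is what keeps the HPU graph shape static. A minimal sketch with toy values (not the real batch layout):

import torch

input_ids = torch.tensor([11, 12, 13, 0, 21, 22, 0, 0, 31])   # left-padded, flattened
valid_indices = torch.tensor([0, 1, 2, 4, 5, 8])              # positions holding real tokens

# Fixed-shape boolean mask: its shape tracks input_ids, not the number of
# valid positions, so it does not introduce a dynamic shape.
mask = torch.zeros_like(input_ids, dtype=torch.bool)
mask[valid_indices] = True
# mask -> tensor([ True,  True,  True, False,  True,  True, False, False,  True])
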
@ -1203,21 +1194,19 @@ class FlashCausalLMBatch(Batch):
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
if input_ids_padded is not None:
self.input_ids = input_ids_padded
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
dim=-1,
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
dim=-1,
)
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
)
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
)
if self.prefill_next_token_indices is not None:
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
if self.prefill_next_token_indices is not None:
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(
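
Shifting prefill_head_indices and prefill_next_token_indices by the cumulative padding keeps them pointing at the same tokens inside the left-padded flat batch; with toy numbers (illustrative only):

import torch

pad_lengths = torch.tensor([0, 1, 2])              # left-padding added per request
pad_cumsum = torch.cumsum(pad_lengths, dim=-1)     # -> tensor([0, 1, 3])

# Last-token index of each request in the unpadded flat batch
# (request lengths 3, 2 and 1).
prefill_head_indices = torch.tensor([2, 4, 5])
shifted = prefill_head_indices + pad_cumsum        # -> tensor([2, 5, 8])
# 2, 5 and 8 are the same tokens' positions once each request is left-padded
# to length 3.
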
@ -1232,7 +1221,6 @@ class FlashCausalLMBatch(Batch):
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
@ -1490,14 +1478,6 @@ class FlashCausalLM(Model):
self.kv_cache_dtype,
self.device,
)
for bs in [1, 2, 4, 8]:
for seqlen in [32, 64, 128, 256, 512, 1024]:
self.warmup_prefill(seqlen, bs)
for bs in [1, 2, 4, 8]:
for block_num in [1, 2, 4, 8, 16]:
self.warmup_decode(bs, block_num * bs)
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def warmup_prefill(self, prompt_len: int, bs: int):
@ -1539,10 +1519,8 @@ class FlashCausalLM(Model):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None,
lm_head_indices=lm_head_indices,
adapter_data=None,
hpu_attention_meta=None,
@ -1562,7 +1540,6 @@ class FlashCausalLM(Model):
for i in range(bs):
slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
cache_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
@ -1575,11 +1552,11 @@ class FlashCausalLM(Model):
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
block_num = cache_lengths_tensor // BLOCK_SIZE + 1
block_tables_valid = []
for i, bt in enumerate(block_tables.tolist()):
block_tables_valid.append(bt[0 : block_num[i]])
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
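
The added block_num computation keeps only the blocks each sequence actually occupies before building the decode metadata; with toy numbers (illustrative only, assuming BLOCK_SIZE = 16):

import torch

BLOCK_SIZE = 16
cache_lengths = torch.tensor([5, 40, 16])                      # cached tokens per sequence
block_tables = torch.tensor([[3, 7, 9], [2, 4, 6], [1, 5, 8]])

# Number of blocks actually in use per sequence
# (kept as a plain list here; the real code keeps it as a tensor).
block_num = (cache_lengths // BLOCK_SIZE + 1).tolist()         # -> [1, 3, 2]
block_tables_valid = [
    bt[: block_num[i]] for i, bt in enumerate(block_tables.tolist())
]
# -> [[3], [2, 4, 6], [1, 5]]
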
@ -1595,10 +1572,8 @@ class FlashCausalLM(Model):
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None,
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
@ -1684,26 +1659,23 @@ class FlashCausalLM(Model):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=None,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
# TODO: adapters are not supported yet; add support in the future
adapter_data=None,
hpu_attention_meta=batch.hpu_attn_meta,
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
# fix the following runtime error in graph replay
# RuntimeError: Neither storage attached to input tensor, not its view
htorch.core.mark_step()
return logits, speculative_logits
@tracer.start_as_current_span("generate_token")
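
In isolation, the slots_pad scatter looks like this (toy tensors; the boolean mask stands in for batch.prefill_cache_indices):

import torch

input_ids = torch.tensor([0, 0, 31, 11, 12, 13])              # left-padded prefill batch
mask = torch.tensor([False, False, True, True, True, True])   # prefill_cache_indices
slots = torch.tensor([96, 128, 129, 130])                     # one kv slot per real token

# Scatter the real slot ids into a tensor shaped like input_ids, so the
# forward pass always receives a slots tensor of the same (static) shape.
slots_pad = torch.zeros_like(input_ids)
slots_pad[mask] = slots
# slots_pad -> tensor([  0,   0,  96, 128, 129, 130])
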
@ -1801,7 +1773,14 @@ class FlashCausalLM(Model):
# instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling:
indices = batch.cu_seqlen_prefill[1:] - 1
batch.position_ids = batch.position_ids[indices]
# the prefill batch is left-padded; skip the padding positions
if batch.prefill_cache_indices is not None:
batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
indices
]
else:
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
indices
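
After prefill, the last position of every request is recovered by first dropping the left-padding with the boolean mask and then indexing at cu_seqlen_prefill[1:] - 1; a toy example (illustrative only):

import torch

# Two requests of lengths 2 and 3, each left-padded by one position (padding
# positions were filled with 1 when position_ids was built).
position_ids = torch.tensor([1, 0, 1, 1, 0, 1, 2])
mask = torch.tensor([False, True, True, False, True, True, True])
cu_seqlen_prefill = torch.tensor([0, 2, 5])       # boundaries of the unpadded requests

indices = cu_seqlen_prefill[1:] - 1               # -> tensor([1, 4])
last_positions = position_ids[mask][indices]      # -> tensor([1, 2])
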

View File

@ -462,11 +462,9 @@ class FlashVlmCausalLM(FlashCausalLM):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,

View File

@ -288,11 +288,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
# TODO list