remove block_tables and prefill_cache_indices, which lead to dynamic shapes

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-27 22:51:21 -07:00
parent 7900be5ac3
commit 1508ee8de1
27 changed files with 88 additions and 530 deletions
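Why these two arguments cause dynamic shapes: gathering rows with prefill_cache_indices makes the size of the tensors written into the KV cache depend on the batch contents, and block_tables likewise changes size with the batch and the number of allocated blocks, so each new batch composition can trigger graph recompilation on HPU. The sketch below (plain PyTorch, not the TGI code itself; sizes and index values are made up) contrasts the removed gather path with the kept store-by-slots path, whose shapes stay fixed.

import torch

# Toy illustration, assuming an 8-token padded batch and a 16-slot cache.
num_tokens, num_heads, head_size = 8, 2, 4
key = torch.randn(num_tokens, num_heads, head_size)

# Removed path: the number of gathered rows depends on the request mix, so
# key_to_cache changes shape from step to step (a dynamic shape).
prefill_cache_indices = torch.tensor([0, 1, 2, 5])   # varies per batch
key_to_cache = key[prefill_cache_indices]
print(key_to_cache.shape)                            # torch.Size([4, 2, 4])

# Kept path: write all num_tokens rows, each at its precomputed slot; every
# tensor involved keeps the same shape no matter how the batch is composed.
cache = torch.zeros(16, num_heads, head_size)
slots = torch.arange(num_tokens)                     # one cache slot per token
cache.index_copy_(0, slots, key)
print(cache.shape)                                   # torch.Size([16, 2, 4])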

View File

@@ -26,7 +26,6 @@ def attention(
     kv_cache: KVCache,
     kv_scales: KVScales,
     seqlen: Seqlen,
-    block_tables: torch.Tensor,
     softmax_scale: float,
     window_size_left: int = -1,
     causal: bool = True,
@@ -61,7 +60,6 @@ def paged_attention(
     kv_cache: KVCache,
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
-    block_tables: torch.Tensor,
     seqlen: Seqlen,
     *,
     kv_scales: KVScales,

View File

@@ -219,10 +219,8 @@ class FlashCohereAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -247,16 +245,9 @@ class FlashCohereAttention(torch.nn.Module):
         self.rotary_emb(query, key, cos, sin)
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -271,7 +262,6 @@ class FlashCohereAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -281,7 +271,6 @@ class FlashCohereAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -356,10 +345,8 @@ class FlashCohereLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -371,10 +358,8 @@ class FlashCohereLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -424,10 +409,8 @@ class FlashCohereModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -446,10 +429,8 @@ class FlashCohereModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -488,10 +469,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -501,10 +480,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -308,10 +308,8 @@ class DbrxAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -329,14 +327,10 @@ class DbrxAttention(torch.nn.Module):
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -351,7 +345,6 @@ class DbrxAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -361,7 +354,6 @@ class DbrxAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -398,10 +390,8 @@ class DbrxNormAttentionNorm(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         normed_hidden_states, res = self.norm_1(hidden_states, residual)
@@ -413,10 +403,8 @@ class DbrxNormAttentionNorm(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
         )
@@ -630,10 +618,8 @@ class DbrxLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         # Self Attention
@@ -644,10 +630,8 @@ class DbrxLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -689,10 +673,8 @@ class DbrxModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -710,10 +692,8 @@ class DbrxModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -744,10 +724,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -757,10 +735,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -256,10 +256,8 @@ class DeepseekV2Attention(torch.nn.Module):
         sin: torch.Tensor,
         cu_seqlen_prefill: torch.Tensor,
         kv_cache: KVCache,
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
@@ -316,15 +314,10 @@ class DeepseekV2Attention(torch.nn.Module):
         value = torch.nn.functional.pad(
             value, (0, self.head_pad_size - self.value_head_size), value=0
         )
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -339,7 +332,6 @@ class DeepseekV2Attention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -349,7 +341,6 @@ class DeepseekV2Attention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -512,10 +503,8 @@ class DeepseekV2Layer(nn.Module):
         sin: torch.Tensor,
         cu_seqlen_prefill: torch.Tensor,
         kv_cache,
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -527,10 +516,8 @@ class DeepseekV2Layer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -577,10 +564,8 @@ class DeepseekV2Model(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -598,10 +583,8 @@ class DeepseekV2Model(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -629,10 +612,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -642,10 +623,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -256,10 +256,8 @@ class DeepseekV3Attention(torch.nn.Module):
         sin: torch.Tensor,
         cu_seqlen_prefill: torch.Tensor,
         kv_cache: KVCache,
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
@@ -317,15 +315,9 @@ class DeepseekV3Attention(torch.nn.Module):
             value, (0, self.head_pad_size - self.value_head_size), value=0
         )
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -340,7 +332,6 @@ class DeepseekV3Attention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -350,7 +341,6 @@ class DeepseekV3Attention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -522,10 +512,8 @@ class DeepseekV3Layer(nn.Module):
         sin: torch.Tensor,
         cu_seqlen_prefill: torch.Tensor,
         kv_cache,
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -537,10 +525,8 @@ class DeepseekV3Layer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -587,10 +573,8 @@ class DeepseekV3Model(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -608,10 +592,8 @@ class DeepseekV3Model(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -639,10 +621,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -652,10 +632,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -235,11 +235,9 @@ class FlashGemma2Attention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
         adapter_data,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
@@ -254,14 +252,10 @@ class FlashGemma2Attention(torch.nn.Module):
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -276,7 +270,6 @@ class FlashGemma2Attention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
                 window_size_left=self.window_size,
                 softcap=self.softcap,
@@ -288,7 +281,6 @@ class FlashGemma2Attention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 softcap=self.softcap,
                 kv_scales=self.kv_scales,
@@ -402,11 +394,9 @@ class FlashGemma2Layer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
         adapter_data,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -418,11 +408,9 @@ class FlashGemma2Layer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
             adapter_data,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -472,11 +460,9 @@ class FlashGemma2Model(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         adapter_data: Optional[torch.Tensor],
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -494,11 +480,9 @@ class FlashGemma2Model(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
                 adapter_data,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -543,10 +527,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -557,11 +539,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
             adapter_data,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -207,10 +207,8 @@ class FlashGemmaAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -226,14 +224,9 @@ class FlashGemmaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -248,7 +241,6 @@ class FlashGemmaAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
                 causal=self.causal,
             )
@@ -259,7 +251,6 @@ class FlashGemmaAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -331,10 +322,8 @@ class FlashGemmaLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -346,10 +335,8 @@ class FlashGemmaLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -395,11 +382,9 @@ class FlashGemmaModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         adapter_data: Optional[torch.Tensor],
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -417,10 +402,8 @@ class FlashGemmaModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -463,10 +446,8 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -477,11 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
             adapter_data,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -209,10 +209,8 @@ class FlashGPT2Attention(torch.nn.Module):
         hidden_states,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -222,16 +220,9 @@ class FlashGPT2Attention(torch.nn.Module):
         key = key.view(-1, self.num_heads, self.head_size)
         value = value.view(-1, self.num_heads, self.head_size)
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -246,7 +237,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -256,7 +246,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -325,10 +314,8 @@ class FlashGPT2Layer(nn.Module):
         residual,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         residual = hidden_states
@@ -339,10 +326,8 @@ class FlashGPT2Layer(nn.Module):
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -393,10 +378,8 @@ class FlashGPT2Model(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -408,10 +391,8 @@ class FlashGPT2Model(torch.nn.Module):
                 residual,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -446,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -462,10 +441,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices=prefill_cache_indices,
             hpu_attention_meta=hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -201,11 +201,9 @@ class FlashLlamaAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache: KVCache,
-        block_tables,
         slots,
         seqlen,
         adapter_data,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
@@ -221,14 +219,9 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -243,7 +236,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -253,7 +245,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -441,12 +432,10 @@ class FlashLlamaLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
         adapter_data,
         cross_attention_states,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -458,11 +447,9 @@ class FlashLlamaLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
             adapter_data,
-            prefill_cache_indices,
             hpu_attention_meta=hpu_attention_meta,
         )
         if self.residual_multiplier is not None:
@@ -554,10 +541,8 @@ class FlashLlamaModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         adapter_data,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         cross_attention_states=None,
@@ -577,12 +562,10 @@ class FlashLlamaModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
                 adapter_data,
                 cross_attention_states,
-                prefill_cache_indices,
                 hpu_attention_meta=hpu_attention_meta,
             )
@@ -643,30 +626,21 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if prefill_cache_indices is not None and slots.size(
-            0
-        ) != prefill_cache_indices.size(0):
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices=prefill_cache_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
             hpu_attention_meta=hpu_attention_meta,

View File

@@ -169,11 +169,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
         # Unused for this model
@@ -276,11 +274,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
             hpu_attention_meta=hpu_attention_meta,
-            prefill_cache_indices=None,
             adapter_data=adapter_data,
         )
         if lm_head_indices is not None:

View File

@@ -178,10 +178,8 @@ class MistralAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         adapter_data,
         hpu_attention_meta,
     ):
@@ -198,14 +196,9 @@ class MistralAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -220,7 +213,6 @@ class MistralAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -231,7 +223,6 @@ class MistralAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -335,10 +326,8 @@ class MistralLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         adapter_data,
         hpu_attention_meta,
     ):
@@ -351,10 +340,8 @@ class MistralLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             adapter_data,
             hpu_attention_meta,
         )
@@ -403,10 +390,8 @@ class MistralModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         adapter_data: Optional[torch.Tensor] = None,
     ):
@@ -424,10 +409,8 @@ class MistralModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 adapter_data,
                 hpu_attention_meta,
             )
@@ -475,30 +458,20 @@ class FlashMistralForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if prefill_cache_indices is not None and slots.size(
-            0
-        ) != prefill_cache_indices.size(0):
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
             adapter_data,
         )

View File

@@ -235,10 +235,8 @@ class MixtralAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -254,14 +252,9 @@ class MixtralAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -276,7 +269,6 @@ class MixtralAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -287,7 +279,6 @@ class MixtralAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -384,10 +375,8 @@ class MixtralLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -399,10 +388,8 @@ class MixtralLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -454,10 +441,8 @@ class MixtralModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
        slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -475,10 +460,8 @@ class MixtralModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -510,29 +493,20 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if prefill_cache_indices is not None and slots.size(
-            0
-        ) != prefill_cache_indices.size(0):
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
         hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -801,12 +801,10 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
         adapter_data,
         cross_attention_states,  # [ IB, ...]
-        prefill_cache_indices,
         hpu_attention_meta,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if cross_attention_states is None:
@@ -911,11 +909,9 @@ class FlashMllamaForConditionalGeneration(nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
         # XXX: Putting these as optional so that the cuda warmup calls can go through.
@@ -979,11 +975,9 @@ class FlashMllamaForConditionalGeneration(nn.Module):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
             hpu_attention_meta=hpu_attention_meta,
-            prefill_cache_indices=prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,

View File

@@ -147,10 +147,8 @@ class FlashNeoxAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -166,14 +164,10 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(query_rot, key_rot, cos, sin)
         qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
         qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
-        if prefill_cache_indices is not None:
-            qkv_to_cache = qkv[prefill_cache_indices]
-        else:
-            qkv_to_cache = qkv
         kv_cache.store(
-            key=qkv_to_cache[:, 1],
-            value=qkv_to_cache[:, 2],
+            key=qkv[:, 1],
+            value=qkv[:, 2],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -188,7 +182,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -198,7 +191,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -268,10 +260,8 @@ class FlashNeoXLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         if self.use_parallel_residual:
@@ -283,10 +273,8 @@ class FlashNeoXLayer(nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache,
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -308,10 +296,8 @@ class FlashNeoXLayer(nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache,
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -363,10 +349,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.embed_in(input_ids)
@@ -384,10 +368,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -417,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -430,10 +410,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

View File

@@ -69,11 +69,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
         # Unused here
@@ -106,11 +104,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
-            block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
             hpu_attention_meta=hpu_attention_meta,
-            prefill_cache_indices=None,
             adapter_data=adapter_data,
         )

View File

@@ -160,10 +160,8 @@ class FlashPhiAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         # Compute query, key, value and split
@@ -190,13 +188,9 @@ class FlashPhiAttention(torch.nn.Module):
         )
         # Reshape key and value and cache
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -210,7 +204,6 @@ class FlashPhiAttention(torch.nn.Module):
                 kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -220,7 +213,6 @@ class FlashPhiAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -287,10 +279,8 @@ class FlashPhiLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -301,10 +291,8 @@ class FlashPhiLayer(nn.Module):
             sin,
            cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -354,10 +342,8 @@ class FlashPhiModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -375,10 +361,8 @@ class FlashPhiModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -409,10 +393,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -422,10 +404,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
         if lm_head_indices is not None:

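The same pattern repeats in every attention module touched here: the full key/value tensors are written into the cache at the given slots with no prefill_cache_indices slicing, and block_tables is no longer threaded through attention or paged_attention. A minimal, self-contained sketch of a slot-indexed cache write (plain torch; BLOCK_SIZE and the flat cache layout are illustrative assumptions, not the repository's exact implementation):

```python
import torch

# Illustrative: a flat KV cache addressed by per-token slot indices.
# Writing the *unsliced* key/value keeps the write shape equal to the
# input shape, so the HPU graph sees static shapes on every step.
BLOCK_SIZE = 16                      # assumed block size
num_blocks, kv_heads, head_size = 8, 2, 64

cache_k = torch.zeros(num_blocks * BLOCK_SIZE, kv_heads, head_size)
cache_v = torch.zeros_like(cache_k)

num_tokens = 5
key = torch.randn(num_tokens, kv_heads, head_size)
value = torch.randn(num_tokens, kv_heads, head_size)
slots = torch.tensor([0, 1, 2, 16, 17])   # flat positions inside the cache

# Rough equivalent of kv_cache.store(key=key, value=value, slots=slots, ...)
cache_k[slots] = key
cache_v[slots] = value

# Decode later reads the same rows back through the slot mapping.
assert torch.equal(cache_k[slots], key)
```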
View File

@ -106,10 +106,8 @@ class Qwen2Attention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -125,14 +123,9 @@ class Qwen2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store( kv_cache.store(
key=kv_to_cache[:, 0], key=kv[:, 0],
value=kv_to_cache[:, 1], value=kv[:, 1],
slots=slots, slots=slots,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
@ -147,7 +140,6 @@ class Qwen2Attention(torch.nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
@ -158,7 +150,6 @@ class Qwen2Attention(torch.nn.Module):
kv_cache, kv_cache,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
@ -230,10 +221,8 @@ class Qwen2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
normed_hidden_states, residual = self.input_layernorm(hidden_states) normed_hidden_states, residual = self.input_layernorm(hidden_states)
@ -245,10 +234,8 @@ class Qwen2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
hidden_states = attn_output + residual hidden_states = attn_output + residual
@ -296,10 +283,8 @@ class Qwen2Model(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -317,10 +302,8 @@ class Qwen2Model(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
@ -364,21 +347,13 @@ class Qwen2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if prefill_cache_indices is not None and prefill_cache_indices.size(
0
) != slots.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
@ -386,10 +361,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

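Qwen2 (and Starcoder2 further below) still pass window_size_left=self.max_past to the prefill attention. A small sketch of what a sliding-window causal mask means, independent of the fused kernel actually used (pure torch; the explicit mask is only an illustration):

```python
import torch

def sliding_window_causal_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    # True where query position q may attend to key position k:
    # causal (k <= q) and, if a window is set, within the last
    # `window_size_left` keys before q.
    q = torch.arange(seq_len).unsqueeze(1)
    k = torch.arange(seq_len).unsqueeze(0)
    causal = k <= q
    if window_size_left < 0:          # -1 conventionally means "no window"
        return causal
    return causal & ((q - k) <= window_size_left)

print(sliding_window_causal_mask(6, 2).int())
```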
View File

@ -182,10 +182,8 @@ class FlashRWAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -203,14 +201,9 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store( kv_cache.store(
key=kv_to_cache[:, 0], key=kv[:, 0],
value=kv_to_cache[:, 1], value=kv[:, 1],
slots=slots, slots=slots,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
@ -225,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
) )
# Decode # Decode
@ -235,7 +227,6 @@ class FlashRWAttention(torch.nn.Module):
kv_cache, kv_cache,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
@ -309,10 +300,8 @@ class FlashRWLargeAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -329,14 +318,9 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store( kv_cache.store(
key=kv_to_cache[:, :, 0].contiguous(), key=kv[:, :, 0].contiguous(),
value=kv_to_cache[:, :, 1].contiguous(), value=kv[:, :, 1].contiguous(),
slots=slots, slots=slots,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
@ -351,7 +335,6 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
) )
# Decode # Decode
@ -361,7 +344,6 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache, kv_cache,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
@ -447,10 +429,8 @@ class FlashRWLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
if self.parallel_attn: if self.parallel_attn:
@ -462,10 +442,8 @@ class FlashRWLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
@ -485,10 +463,8 @@ class FlashRWLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
@ -573,10 +549,8 @@ class FlashRWLargeLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
# Layer norm. # Layer norm.
@ -589,10 +563,8 @@ class FlashRWLargeLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
@ -651,10 +623,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
@ -672,10 +642,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
@ -703,10 +671,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
@ -716,10 +682,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

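FlashRWLargeAttention keeps its key/value groups in one [num_tokens, num_groups, 2, head_size] tensor, so the per-group key and value slices are strided views and are materialized with .contiguous() before being handed to the cache. A quick illustration of why (plain torch, illustrative shapes):

```python
import torch

num_tokens, num_groups, head_size = 4, 2, 8
kv = torch.randn(num_tokens, num_groups, 2, head_size)

key_view = kv[:, :, 0]            # strided view, shares storage with kv
print(key_view.is_contiguous())   # False: the stride skips over the value rows

key = kv[:, :, 0].contiguous()    # dense copy, safe to hand to cache/attention kernels
print(key.is_contiguous())        # True
```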
View File

@ -265,10 +265,8 @@ class FlashMQAttention(torch.nn.Module):
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
qkv = self.c_attn(hidden_states) qkv = self.c_attn(hidden_states)
@ -282,14 +280,9 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size)
if prefill_cache_indices is not None:
key_value_to_cache = key_value[prefill_cache_indices]
else:
key_value_to_cache = key_value
kv_cache.store( kv_cache.store(
key=key_value_to_cache[:, 0], key=key_value[:, 0],
value=key_value_to_cache[:, 1], value=key_value[:, 1],
slots=slots, slots=slots,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
@ -304,7 +297,6 @@ class FlashMQAttention(torch.nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
) )
# Decode # Decode
@ -314,7 +306,6 @@ class FlashMQAttention(torch.nn.Module):
kv_cache, kv_cache,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
@ -379,10 +370,8 @@ class Block(nn.Module):
residual, residual,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
@ -390,10 +379,8 @@ class Block(nn.Module):
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
@ -445,10 +432,8 @@ class FlashSantacoderModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -463,10 +448,8 @@ class FlashSantacoderModel(nn.Module):
residual, residual,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
@ -496,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -509,10 +490,8 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
hpu_attention_meta, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
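The Santacoder path is multi-query attention: a single shared KV head (key_value.view(-1, 2, 1, head_size)), with kv_head_mapping telling paged attention which KV head each query head reads. A small sketch of such a mapping for MQA/GQA (illustrative; the repository builds this tensor elsewhere, so the construction below is an assumption):

```python
import torch

def make_kv_head_mapping(num_heads: int, num_kv_heads: int) -> torch.Tensor:
    # Each query head is assigned one KV head. For MQA (num_kv_heads == 1)
    # every query head maps to head 0; for GQA, heads are shared in groups.
    group = num_heads // num_kv_heads
    return torch.arange(num_kv_heads, dtype=torch.int32).repeat_interleave(group)

print(make_kv_head_mapping(num_heads=8, num_kv_heads=1))  # MQA: all zeros
print(make_kv_head_mapping(num_heads=8, num_kv_heads=2))  # GQA: [0,0,0,0,1,1,1,1]
```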

View File

@ -235,10 +235,8 @@ class Starcoder2Attention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
): ):
@ -255,14 +253,9 @@ class Starcoder2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
kv_cache.store( kv_cache.store(
key=kv_to_cache[:, 0], key=kv[:, 0],
value=kv_to_cache[:, 1], value=kv[:, 1],
slots=slots, slots=slots,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
@ -277,7 +270,6 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
@ -288,7 +280,6 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache, kv_cache,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
@ -448,10 +439,8 @@ class Starcoder2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
): ):
@ -464,10 +453,8 @@ class Starcoder2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )
@ -518,10 +505,8 @@ class Starcoder2Model(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data, adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
@ -540,10 +525,8 @@ class Starcoder2Model(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )
@ -589,29 +572,20 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )

View File

@ -740,11 +740,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None,
@ -843,11 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables,
slots=slots, slots=slots,
seqlen=seqlen, seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data, adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -483,11 +483,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None,
@ -587,11 +585,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables,
slots=slots, slots=slots,
seqlen=seqlen, seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data, adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -906,11 +906,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None,
@ -938,11 +936,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables,
slots=slots, slots=slots,
seqlen=seqlen, seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -480,11 +480,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata], hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None,
@ -511,11 +509,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables,
slots=slots, slots=slots,
seqlen=seqlen, seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -25,7 +25,7 @@ from typing import (
Dict, Dict,
Union, Union,
) )
import torch.nn.functional as F
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.models import Model from text_generation_server.models import Model
@ -116,7 +116,6 @@ def prepare_for_decode(
block_list = flatten(block_tables) block_list = flatten(block_tables)
block_groups = flatten(block_groups) block_groups = flatten(block_groups)
block_usage = flatten(block_usage) block_usage = flatten(block_usage)
assert len(block_list) == len(block_groups) assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage) assert len(block_list) == len(block_usage)
if use_contiguous_pa: if use_contiguous_pa:
@ -979,29 +978,27 @@ class FlashCausalLMBatch(Batch):
# padding to left to work with sliding window # padding to left to work with sliding window
# use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate
# the right logit position # the right logit position
input_ids_padded = None input_ids_padded_length = []
input_ids_padded_length = None # need extra pad to match warmup seq
extra_pad = 0
if isinstance(self.input_ids, list) and len(self) > 1: if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded = []
input_ids_padded_length = [] input_ids_padded_length = []
input_ids = []
for input_id in self.input_ids: for input_id in self.input_ids:
padded = self.max_input_length - len(input_id) padded = self.max_input_length - len(input_id) + extra_pad
input_id_padded = input_id
if padded > 0: if padded > 0:
input_id_padded = [0] * padded + input_id_padded input_id = [0] * padded + input_id
input_ids_padded.append(input_id_padded) input_ids.append(input_id)
input_ids_padded_length.append(padded) input_ids_padded_length.append(padded)
input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) input_ids = np.concatenate(input_ids, dtype=np.int64)
input_ids_padded = torch.tensor(
input_ids_padded, dtype=torch.int64, device=device
)
if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
logger.error("should not be here, prefill self.input_ids is a tensor")
self.input_lengths_tensor = torch.tensor( self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device self.input_lengths, dtype=torch.int32, device=device
@ -1052,10 +1049,9 @@ class FlashCausalLMBatch(Batch):
request_position_ids = torch.arange( request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32 cache_length, cache_length + input_length, dtype=torch.int32
) )
if input_ids_padded is not None: request_position_ids = F.pad(
position_ids.append( request_position_ids, (input_ids_padded_length[i], 0), value=1
torch.ones(input_ids_padded_length[i], dtype=torch.int32) )
)
position_ids.append(request_position_ids) position_ids.append(request_position_ids)
if not r.slots: if not r.slots:
@ -1079,12 +1075,11 @@ class FlashCausalLMBatch(Batch):
cumulative_slot_tokens += len(request_slots) cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill # Create tensor to slice into the kv tensor in prefill
if input_ids_padded is not None: # hpu need request_prefill_cache_indices to skip padding in kv cache
# hpu need request_prefill_cache_indices to skip padding in kv cache sliding_window = get_sliding_windows()
sliding_window = get_sliding_windows() if sliding_window is None:
if sliding_window is None: sliding_window = input_length
sliding_window = input_length cumulative_length += input_ids_padded_length[i]
cumulative_length += input_ids_padded_length[i]
if sliding_window is not None: if sliding_window is not None:
request_prefill_cache_indices = torch.arange( request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window), cumulative_length + max(0, input_length - sliding_window),
@ -1105,8 +1100,7 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1 prefill_out_cumulative_length += 1
if sliding_window is not None: prefill_cache_indices.append(request_prefill_cache_indices)
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index() ADAPTER_TO_INDEX = get_adapter_to_index()
if ADAPTER_TO_INDEX: if ADAPTER_TO_INDEX:
@ -1171,23 +1165,20 @@ class FlashCausalLMBatch(Batch):
position_ids = torch.cat(position_ids) position_ids = torch.cat(position_ids)
if slot_indices: if slot_indices:
slot_indices = torch.cat(slot_indices) slot_indices = torch.cat(slot_indices)
if sliding_window is not None: prefill_cache_indices = torch.cat(prefill_cache_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
else: else:
if position_ids: if position_ids:
position_ids = position_ids[0] position_ids = position_ids[0]
if slot_indices: if slot_indices:
slot_indices = slot_indices[0] slot_indices = slot_indices[0]
if sliding_window is not None: prefill_cache_indices = prefill_cache_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
self.position_ids = position_ids.to(device) self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device) self.slot_indices = slot_indices.to(device)
self.prefill_cu_outlens = prefill_cu_outlens self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = ( self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
prefill_cache_indices.to(device) if sliding_window is not None else None self.prefill_cache_indices[prefill_cache_indices.to(device)] = True
)
if all_prefill_logprobs: if all_prefill_logprobs:
prefill_head_indices = None prefill_head_indices = None
@ -1203,21 +1194,19 @@ class FlashCausalLMBatch(Batch):
self.prefill_head_indices = prefill_head_indices self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices self.prefill_next_token_indices = prefill_next_token_indices
if input_ids_padded is not None: input_ids_padded_length_tensor = torch.cumsum(
self.input_ids = input_ids_padded torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
input_ids_padded_length_tensor = torch.cumsum( dim=-1,
torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), )
dim=-1, if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
) )
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
)
if self.prefill_next_token_indices is not None: if self.prefill_next_token_indices is not None:
self.prefill_next_token_indices = ( self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor self.prefill_next_token_indices + input_ids_padded_length_tensor
) )
if adapter_set: if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to( adapter_indices = torch.cat(adapter_indices_list).to(
@ -1232,7 +1221,6 @@ class FlashCausalLMBatch(Batch):
adapter_segments = torch.tensor( adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device adapter_segments, dtype=torch.int32, device=device
) )
self.adapter_meta = AdapterBatchMetadata( self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices, adapter_indices=adapter_indices,
adapter_set=adapter_set, adapter_set=adapter_set,
@ -1490,14 +1478,6 @@ class FlashCausalLM(Model):
self.kv_cache_dtype, self.kv_cache_dtype,
self.device, self.device,
) )
for bs in [1, 2, 4, 8]:
for seqlen in [32, 64, 128, 256, 512, 1024]:
self.warmup_prefill(seqlen, bs)
for bs in [1, 2, 4, 8]:
for block_num in [1, 2, 4, 8, 16]:
self.warmup_decode(bs, block_num * bs)
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def warmup_prefill(self, prompt_len: int, bs: int): def warmup_prefill(self, prompt_len: int, bs: int):
@ -1539,10 +1519,8 @@ class FlashCausalLM(Model):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
adapter_data=None, adapter_data=None,
hpu_attention_meta=None, hpu_attention_meta=None,
@ -1562,7 +1540,6 @@ class FlashCausalLM(Model):
for i in range(bs): for i in range(bs):
slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
slots = torch.tensor(slots, dtype=torch.int64, device=self.device) slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
cache_lengths_tensor = ( cache_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * past_len torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
@ -1575,11 +1552,11 @@ class FlashCausalLM(Model):
cache_lengths=cache_lengths_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill, cu_seqlen_q=cu_seqlen_prefill,
) )
block_num = cache_lengths_tensor // BLOCK_SIZE + 1 block_num = cache_lengths_tensor // BLOCK_SIZE + 1
block_tables_valid = [] block_tables_valid = []
for i, bt in enumerate(block_tables.tolist()): for i, bt in enumerate(block_tables.tolist()):
block_tables_valid.append(bt[0 : block_num[i]]) block_tables_valid.append(bt[0 : block_num[i]])
hpu_attention_meta = prepare_for_decode( hpu_attention_meta = prepare_for_decode(
self.dtype, self.dtype,
self.use_contiguous_pa, self.use_contiguous_pa,
@ -1595,10 +1572,8 @@ class FlashCausalLM(Model):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
adapter_data=None, adapter_data=None,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
@ -1684,26 +1659,23 @@ class FlashCausalLM(Model):
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False kwargs["bypass_hpu_graphs"] = False
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=None,
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
# TODO: adapters are not supported yet; add support in the future # TODO: adapters are not supported yet; add support in the future
adapter_data=None, adapter_data=None,
hpu_attention_meta=batch.hpu_attn_meta, hpu_attention_meta=batch.hpu_attn_meta,
**kwargs, **kwargs,
) )
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
# fix the following runtime error during graph replay
# RuntimeError: Neither storage attached to input tensor, not its view
htorch.core.mark_step()
return logits, speculative_logits return logits, speculative_logits
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
@ -1801,7 +1773,14 @@ class FlashCausalLM(Model):
# instantly become of shape [BATCH_SIZE] # instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling: if prefill and finished_prefilling:
indices = batch.cu_seqlen_prefill[1:] - 1 indices = batch.cu_seqlen_prefill[1:] - 1
batch.position_ids = batch.position_ids[indices] # pad on the left
if batch.prefill_cache_indices is not None:
batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
indices
]
else:
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices] batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
indices indices

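The batch-level change above replaces the integer prefill_cache_indices with a boolean mask the same length as the left-padded input_ids, so its shape never changes between batches; padded positions simply stay False, and the last real token of each request is still recoverable for the logits. A self-contained sketch of that bookkeeping (plain torch; the names mirror the code above, but the helper itself is illustrative):

```python
import torch

# Two requests of different lengths, left-padded to a common max length.
requests = [[11, 12, 13], [21, 22, 23, 24, 25]]
max_len = max(len(r) for r in requests)

input_ids, pad_lengths, valid_idx = [], [], []
offset = 0
for r in requests:
    pad = max_len - len(r)
    input_ids.extend([0] * pad + r)
    pad_lengths.append(pad)
    # flat indices of the real (non-padded) tokens for this request
    valid_idx.extend(range(offset + pad, offset + max_len))
    offset += max_len

input_ids = torch.tensor(input_ids)
valid_idx = torch.tensor(valid_idx)

# Boolean mask, same length as input_ids: True on real tokens only.
prefill_cache_indices = torch.zeros_like(input_ids, dtype=torch.bool)
prefill_cache_indices[valid_idx] = True

# Cumulative padding shifts any index computed on the unpadded layout
# (e.g. prefill_head_indices / prefill_next_token_indices).
pad_offsets = torch.cumsum(torch.tensor(pad_lengths), dim=-1)

# Last real token of each request in the padded layout:
lengths = torch.tensor([len(r) for r in requests])
cu = torch.cumsum(lengths, 0)            # like cu_seqlen_prefill[1:]
last_real = (cu - 1) + pad_offsets       # indices into the padded input_ids
print(input_ids[last_real])              # tensor([13, 25])
```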
View File

@ -462,11 +462,9 @@ class FlashVlmCausalLM(FlashCausalLM):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta, hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values, pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask, pixel_attention_mask=batch.pixel_attention_mask,

View File

@ -288,11 +288,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta, hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
# TODO list # TODO list