Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-04 16:32:13 +00:00)

Commit 1508ee8de1 (parent 7900be5ac3)

Remove block_tables and prefill_cache_indices, which would otherwise lead to dynamic shapes.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
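Across every file touched below the change follows one pattern: block_tables and prefill_cache_indices disappear from the forward signatures and call chains, and the key/value tensors are written to the cache without first being gathered through prefill_cache_indices. A minimal before/after sketch of that caching step, assuming a kv_cache.store() interface shaped like the one in the hunks; the helper names store_kv_before/store_kv_after are illustrative and not from the repository:

# Before: gathering with runtime indices makes the tensor handed to the cache
# data-dependent in its first dimension, i.e. a dynamic shape for graph capture.
def store_kv_before(kv_cache, kv, slots, kv_scales, prefill_cache_indices=None):
    if prefill_cache_indices is not None:
        kv_to_cache = kv[prefill_cache_indices]  # shape depends on the indices
    else:
        kv_to_cache = kv
    kv_cache.store(
        key=kv_to_cache[:, 0],
        value=kv_to_cache[:, 1],
        slots=slots,
        kv_scales=kv_scales,
    )

# After: the full key/value tensors are stored, so the shape stays fixed for a
# given batch/sequence-length bucket.
def store_kv_after(kv_cache, kv, slots, kv_scales):
    kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots, kv_scales=kv_scales)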
@@ -26,7 +26,6 @@ def attention(
     kv_cache: KVCache,
     kv_scales: KVScales,
     seqlen: Seqlen,
-    block_tables: torch.Tensor,
     softmax_scale: float,
     window_size_left: int = -1,
     causal: bool = True,
@@ -61,7 +60,6 @@ def paged_attention(
     kv_cache: KVCache,
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
-    block_tables: torch.Tensor,
     seqlen: Seqlen,
     *,
     kv_scales: KVScales,
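For background on why block_tables forces a dynamic shape (this is general paged-attention behaviour, not something stated in the diff): the table maps each sequence's logical KV-cache blocks to physical block ids, so its size changes with the batch composition. An illustrative tensor:

import torch

# [num_seqs, max_blocks_per_seq]: both dimensions vary from batch to batch;
# here two sequences occupy three and two physical blocks respectively
# (-1 is used as padding purely for illustration).
block_tables = torch.tensor(
    [
        [0, 1, 2, -1],
        [3, 4, -1, -1],
    ],
    dtype=torch.int32,
)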
@@ -219,10 +219,8 @@ class FlashCohereAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -247,16 +245,9 @@ class FlashCohereAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
@@ -271,7 +262,6 @@ class FlashCohereAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -281,7 +271,6 @@ class FlashCohereAttention(torch.nn.Module):
-            block_tables,
@@ -356,10 +345,8 @@ class FlashCohereLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -371,10 +358,8 @@ class FlashCohereLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -424,10 +409,8 @@ class FlashCohereModel(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -446,10 +429,8 @@ class FlashCohereModel(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -488,10 +469,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -501,10 +480,8 @@ class FlashCohereForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -308,10 +308,8 @@ class DbrxAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -329,14 +327,10 @@ class DbrxAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -351,7 +345,6 @@ class DbrxAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -361,7 +354,6 @@ class DbrxAttention(torch.nn.Module):
-            block_tables,
@@ -398,10 +390,8 @@ class DbrxNormAttentionNorm(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -413,10 +403,8 @@ class DbrxNormAttentionNorm(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -630,10 +618,8 @@ class DbrxLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -644,10 +630,8 @@ class DbrxLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -689,10 +673,8 @@ class DbrxModel(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -710,10 +692,8 @@ class DbrxModel(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -744,10 +724,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -757,10 +735,8 @@ class FlashDbrxForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -256,10 +256,8 @@ class DeepseekV2Attention(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -316,15 +314,10 @@ class DeepseekV2Attention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
@@ -339,7 +332,6 @@ class DeepseekV2Attention(torch.nn.Module):
-            block_tables=block_tables,
@@ -349,7 +341,6 @@ class DeepseekV2Attention(torch.nn.Module):
-            block_tables,
@@ -512,10 +503,8 @@ class DeepseekV2Layer(nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -527,10 +516,8 @@ class DeepseekV2Layer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -577,10 +564,8 @@ class DeepseekV2Model(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -598,10 +583,8 @@ class DeepseekV2Model(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -629,10 +612,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -642,10 +623,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -256,10 +256,8 @@ class DeepseekV3Attention(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -317,15 +315,9 @@ class DeepseekV3Attention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
@@ -340,7 +332,6 @@ class DeepseekV3Attention(torch.nn.Module):
-            block_tables=block_tables,
@@ -350,7 +341,6 @@ class DeepseekV3Attention(torch.nn.Module):
-            block_tables,
@@ -522,10 +512,8 @@ class DeepseekV3Layer(nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -537,10 +525,8 @@ class DeepseekV3Layer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -587,10 +573,8 @@ class DeepseekV3Model(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -608,10 +592,8 @@ class DeepseekV3Model(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -639,10 +621,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -652,10 +632,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -235,11 +235,9 @@ class FlashGemma2Attention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -254,14 +252,10 @@ class FlashGemma2Attention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -276,7 +270,6 @@ class FlashGemma2Attention(torch.nn.Module):
-            block_tables=block_tables,
@@ -288,7 +281,6 @@ class FlashGemma2Attention(torch.nn.Module):
-            block_tables,
@@ -402,11 +394,9 @@ class FlashGemma2Layer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -418,11 +408,9 @@ class FlashGemma2Layer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -472,11 +460,9 @@ class FlashGemma2Model(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -494,11 +480,9 @@ class FlashGemma2Model(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -543,10 +527,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -557,11 +539,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -207,10 +207,8 @@ class FlashGemmaAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -226,14 +224,9 @@ class FlashGemmaAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -248,7 +241,6 @@ class FlashGemmaAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -259,7 +251,6 @@ class FlashGemmaAttention(torch.nn.Module):
-            block_tables,
@@ -331,10 +322,8 @@ class FlashGemmaLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -346,10 +335,8 @@ class FlashGemmaLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -395,11 +382,9 @@ class FlashGemmaModel(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -417,10 +402,8 @@ class FlashGemmaModel(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -463,10 +446,8 @@ class FlashGemmaForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -477,11 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -209,10 +209,8 @@ class FlashGPT2Attention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -222,16 +220,9 @@ class FlashGPT2Attention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
@@ -246,7 +237,6 @@ class FlashGPT2Attention(torch.nn.Module):
-            block_tables=block_tables,
@@ -256,7 +246,6 @@ class FlashGPT2Attention(torch.nn.Module):
-            block_tables,
@@ -325,10 +314,8 @@ class FlashGPT2Layer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -339,10 +326,8 @@ class FlashGPT2Layer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -393,10 +378,8 @@ class FlashGPT2Model(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -408,10 +391,8 @@ class FlashGPT2Model(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -446,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor] = None,
@@ -462,10 +441,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices=prefill_cache_indices,
@@ -201,11 +201,9 @@ class FlashLlamaAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -221,14 +219,9 @@ class FlashLlamaAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -243,7 +236,6 @@ class FlashLlamaAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -253,7 +245,6 @@ class FlashLlamaAttention(torch.nn.Module):
-            block_tables,
@@ -441,12 +432,10 @@ class FlashLlamaLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -458,11 +447,9 @@ class FlashLlamaLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -554,10 +541,8 @@ class FlashLlamaModel(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -577,12 +562,10 @@ class FlashLlamaModel(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -643,30 +626,21 @@ class FlashLlamaForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor] = None,
-        if prefill_cache_indices is not None and slots.size(
-            0
-        ) != prefill_cache_indices.size(0):
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
         hidden_states = self.model(
-            block_tables,
-            prefill_cache_indices=prefill_cache_indices,
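The *ForCausalLM forward entry points (Llama above; Mistral, Mixtral and Qwen2 below) also lose the guard that sliced slots down to the positions picked out by prefill_cache_indices. Quoting the removed logic from the hunk above, with the shape consequence noted in comments:

        # Removed (pre-change): slots had one entry per position of the whole kv
        # tensor, so it was gathered with the same runtime indices, i.e. yet
        # another data-dependent shape.
        if prefill_cache_indices is not None and slots.size(
            0
        ) != prefill_cache_indices.size(0):
            # Slots also need to be sliced as it has the same size as the whole kv tensor
            slots = slots[prefill_cache_indices]
        # Post-change: every position is cached, so slots is used as provided.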
@@ -169,11 +169,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -276,11 +274,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
-            block_tables=block_tables,
-            prefill_cache_indices=None,
@@ -178,10 +178,8 @@ class MistralAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -198,14 +196,9 @@ class MistralAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -220,7 +213,6 @@ class MistralAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -231,7 +223,6 @@ class MistralAttention(torch.nn.Module):
-            block_tables,
@@ -335,10 +326,8 @@ class MistralLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -351,10 +340,8 @@ class MistralLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -403,10 +390,8 @@ class MistralModel(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -424,10 +409,8 @@ class MistralModel(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -475,30 +458,20 @@ class FlashMistralForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
-        if prefill_cache_indices is not None and slots.size(
-            0
-        ) != prefill_cache_indices.size(0):
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
         hidden_states = self.model(
-            block_tables,
-            prefill_cache_indices,
@@ -235,10 +235,8 @@ class MixtralAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -254,14 +252,9 @@ class MixtralAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -276,7 +269,6 @@ class MixtralAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -287,7 +279,6 @@ class MixtralAttention(torch.nn.Module):
-            block_tables,
@@ -384,10 +375,8 @@ class MixtralLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -399,10 +388,8 @@ class MixtralLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -454,10 +441,8 @@ class MixtralModel(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -475,10 +460,8 @@ class MixtralModel(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -510,29 +493,20 @@ class FlashMixtralForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
-        if prefill_cache_indices is not None and slots.size(
-            0
-        ) != prefill_cache_indices.size(0):
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
         hidden_states = self.model(
-            block_tables,
-            prefill_cache_indices,
@@ -801,12 +801,10 @@ class FlashLlamaCrossLayer(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -911,11 +909,9 @@ class FlashMllamaForConditionalGeneration(nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -979,11 +975,9 @@ class FlashMllamaForConditionalGeneration(nn.Module):
-            block_tables=block_tables,
-            prefill_cache_indices=prefill_cache_indices,
@@ -147,10 +147,8 @@ class FlashNeoxAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -166,14 +164,10 @@ class FlashNeoxAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            qkv_to_cache = qkv[prefill_cache_indices]
-        else:
-            qkv_to_cache = qkv
         kv_cache.store(
-            key=qkv_to_cache[:, 1],
-            value=qkv_to_cache[:, 2],
+            key=qkv[:, 1],
+            value=qkv[:, 2],
@@ -188,7 +182,6 @@ class FlashNeoxAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -198,7 +191,6 @@ class FlashNeoxAttention(torch.nn.Module):
-            block_tables,
@@ -268,10 +260,8 @@ class FlashNeoXLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -283,10 +273,8 @@ class FlashNeoXLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -308,10 +296,8 @@ class FlashNeoXLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -363,10 +349,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -384,10 +368,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
-                block_tables,
-                prefill_cache_indices,
@@ -417,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -430,10 +410,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-            block_tables,
-            prefill_cache_indices,
@@ -69,11 +69,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor] = None,
@@ -106,11 +104,9 @@ class PaliGemmaForConditionalGeneration(nn.Module):
-            block_tables=block_tables,
-            prefill_cache_indices=None,
@@ -160,10 +160,8 @@ class FlashPhiAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -190,13 +188,9 @@ class FlashPhiAttention(torch.nn.Module):
         # Reshape key and value and cache
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -210,7 +204,6 @@ class FlashPhiAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -220,7 +213,6 @@ class FlashPhiAttention(torch.nn.Module):
-            block_tables,
@@ -287,10 +279,8 @@ class FlashPhiLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -301,10 +291,8 @@ class FlashPhiLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -354,10 +342,8 @@ class FlashPhiModel(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -375,10 +361,8 @@ class FlashPhiModel(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -409,10 +393,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -422,10 +404,8 @@ class FlashPhiForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -106,10 +106,8 @@ class Qwen2Attention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -125,14 +123,9 @@ class Qwen2Attention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -147,7 +140,6 @@ class Qwen2Attention(torch.nn.Module):
-            block_tables=block_tables,
@@ -158,7 +150,6 @@ class Qwen2Attention(torch.nn.Module):
-            block_tables,
@@ -230,10 +221,8 @@ class Qwen2Layer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -245,10 +234,8 @@ class Qwen2Layer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -296,10 +283,8 @@ class Qwen2Model(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -317,10 +302,8 @@ class Qwen2Model(torch.nn.Module):
-                block_tables,
-                prefill_cache_indices,
@@ -364,21 +347,13 @@ class Qwen2ForCausalLM(torch.nn.Module):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
-        if prefill_cache_indices is not None and prefill_cache_indices.size(
-            0
-        ) != slots.size(0):
-            # Slots also need to be sliced as it has the same size as the whole kv tensor
-            slots = slots[prefill_cache_indices]
@@ -386,10 +361,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -182,10 +182,8 @@ class FlashRWAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -203,14 +201,9 @@ class FlashRWAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, 0],
-            value=kv_to_cache[:, 1],
+            key=kv[:, 0],
+            value=kv[:, 1],
@@ -225,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -235,7 +227,6 @@ class FlashRWAttention(torch.nn.Module):
-            block_tables,
@@ -309,10 +300,8 @@ class FlashRWLargeAttention(torch.nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -329,14 +318,9 @@ class FlashRWLargeAttention(torch.nn.Module):
-        if prefill_cache_indices is not None:
-            kv_to_cache = kv[prefill_cache_indices]
-        else:
-            kv_to_cache = kv
         kv_cache.store(
-            key=kv_to_cache[:, :, 0].contiguous(),
-            value=kv_to_cache[:, :, 1].contiguous(),
+            key=kv[:, :, 0].contiguous(),
+            value=kv[:, :, 1].contiguous(),
@@ -351,7 +335,6 @@ class FlashRWLargeAttention(torch.nn.Module):
-            block_tables=block_tables,
@@ -361,7 +344,6 @@ class FlashRWLargeAttention(torch.nn.Module):
-            block_tables,
@@ -447,10 +429,8 @@ class FlashRWLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -462,10 +442,8 @@ class FlashRWLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -485,10 +463,8 @@ class FlashRWLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -573,10 +549,8 @@ class FlashRWLargeLayer(nn.Module):
-        block_tables,
-        prefill_cache_indices,
@@ -589,10 +563,8 @@ class FlashRWLargeLayer(nn.Module):
-            block_tables,
-            prefill_cache_indices,
@@ -651,10 +623,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -672,10 +642,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
-                block_tables,
-                prefill_cache_indices,
@@ -703,10 +671,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
-        block_tables: torch.Tensor,
-        prefill_cache_indices: Optional[torch.Tensor],
@@ -716,10 +682,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
-            block_tables,
-            prefill_cache_indices,
@ -265,10 +265,8 @@ class FlashMQAttention(torch.nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
qkv = self.c_attn(hidden_states)
@ -282,14 +280,9 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)

if prefill_cache_indices is not None:
key_value_to_cache = key_value[prefill_cache_indices]
else:
key_value_to_cache = key_value

kv_cache.store(
key=key_value_to_cache[:, 0],
value=key_value_to_cache[:, 1],
key=key_value[:, 0],
value=key_value[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -304,7 +297,6 @@ class FlashMQAttention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
@ -314,7 +306,6 @@ class FlashMQAttention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -379,10 +370,8 @@ class Block(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
@ -390,10 +379,8 @@ class Block(nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)

@ -445,10 +432,8 @@ class FlashSantacoderModel(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -463,10 +448,8 @@ class FlashSantacoderModel(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)

@ -496,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -509,10 +490,8 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
hpu_attention_meta,
)
if lm_head_indices is not None:
|
@ -235,10 +235,8 @@ class Starcoder2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
@ -255,14 +253,9 @@ class Starcoder2Attention(torch.nn.Module):

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv

kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
@ -277,7 +270,6 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
@ -288,7 +280,6 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
@ -448,10 +439,8 @@ class Starcoder2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
@ -464,10 +453,8 @@ class Starcoder2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
@ -518,10 +505,8 @@ class Starcoder2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
@ -540,10 +525,8 @@ class Starcoder2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
@ -589,29 +572,20 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None and slots.size(
0
) != prefill_cache_indices.size(0):
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]

hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
|
@ -740,11 +740,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
@ -843,11 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
|
@ -483,11 +483,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
@ -587,11 +585,9 @@ class Idefics3ForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
|
@ -906,11 +906,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
@ -938,11 +936,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
|
@ -480,11 +480,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
@ -511,11 +509,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
|
@ -25,7 +25,7 @@ from typing import (
Dict,
Union,
)

import torch.nn.functional as F
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.models import Model
@ -116,7 +116,6 @@ def prepare_for_decode(
block_list = flatten(block_tables)
block_groups = flatten(block_groups)
block_usage = flatten(block_usage)

assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
if use_contiguous_pa:
@ -979,29 +978,27 @@ class FlashCausalLMBatch(Batch):
# padding to left to work with sliding window
# use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate
# the right logit position
input_ids_padded = None
input_ids_padded_length = None
input_ids_padded_length = []
# need extra pad to match warmup seq
extra_pad = 0
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded = []
input_ids_padded_length = []
input_ids = []
for input_id in self.input_ids:
padded = self.max_input_length - len(input_id)
input_id_padded = input_id
padded = self.max_input_length - len(input_id) + extra_pad
if padded > 0:
input_id_padded = [0] * padded + input_id_padded
input_ids_padded.append(input_id_padded)
input_id = [0] * padded + input_id
input_ids.append(input_id)
input_ids_padded_length.append(padded)
input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64)
input_ids_padded = torch.tensor(
input_ids_padded, dtype=torch.int64, device=device
)

if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
logger.error("should not be here, prefill self.input_ids is a tensor")

self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
@ -1052,10 +1049,8 @@ class FlashCausalLMBatch(Batch):
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
if input_ids_padded is not None:
position_ids.append(
torch.ones(input_ids_padded_length[i], dtype=torch.int32)
)
request_position_ids = F.pad(
request_position_ids, (input_ids_padded_length[i], 0), value=1
)
position_ids.append(request_position_ids)

if not r.slots:
@ -1079,12 +1075,11 @@ class FlashCausalLMBatch(Batch):
cumulative_slot_tokens += len(request_slots)

# Create tensor to slice into the kv tensor in prefill
if input_ids_padded is not None:
# hpu need request_prefill_cache_indices to skip padding in kv cache
sliding_window = get_sliding_windows()
if sliding_window is None:
sliding_window = input_length
cumulative_length += input_ids_padded_length[i]
# hpu need request_prefill_cache_indices to skip padding in kv cache
sliding_window = get_sliding_windows()
if sliding_window is None:
sliding_window = input_length
cumulative_length += input_ids_padded_length[i]
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
@ -1105,8 +1100,7 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1

if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
prefill_cache_indices.append(request_prefill_cache_indices)

ADAPTER_TO_INDEX = get_adapter_to_index()
if ADAPTER_TO_INDEX:
@ -1171,23 +1165,20 @@ class FlashCausalLMBatch(Batch):
position_ids = torch.cat(position_ids)
if slot_indices:
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
if position_ids:
position_ids = position_ids[0]
if slot_indices:
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
prefill_cache_indices = prefill_cache_indices[0]

self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)

self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
self.prefill_cache_indices[prefill_cache_indices.to(device)] = True

if all_prefill_logprobs:
prefill_head_indices = None
@ -1203,21 +1194,19 @@ class FlashCausalLMBatch(Batch):

self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
if input_ids_padded is not None:
self.input_ids = input_ids_padded
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
dim=-1,
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
dim=-1,
)
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
)
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
)

if self.prefill_next_token_indices is not None:
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
if self.prefill_next_token_indices is not None:
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)

if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(
@ -1232,7 +1221,6 @@ class FlashCausalLMBatch(Batch):
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)

self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
@ -1490,14 +1478,6 @@ class FlashCausalLM(Model):
self.kv_cache_dtype,
self.device,
)
for bs in [1, 2, 4, 8]:
for seqlen in [32, 64, 128, 256, 512, 1024]:
self.warmup_prefill(seqlen, bs)

for bs in [1, 2, 4, 8]:
for block_num in [1, 2, 4, 8, 16]:
self.warmup_decode(bs, block_num * bs)

return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

def warmup_prefill(self, prompt_len: int, bs: int):
@ -1539,10 +1519,8 @@ class FlashCausalLM(Model):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None,
lm_head_indices=lm_head_indices,
adapter_data=None,
hpu_attention_meta=None,
@ -1562,7 +1540,6 @@ class FlashCausalLM(Model):
for i in range(bs):
slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)

input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
cache_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
@ -1575,11 +1552,11 @@ class FlashCausalLM(Model):
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)

block_num = cache_lengths_tensor // BLOCK_SIZE + 1
block_tables_valid = []
for i, bt in enumerate(block_tables.tolist()):
block_tables_valid.append(bt[0 : block_num[i]])

hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
@ -1595,10 +1572,8 @@ class FlashCausalLM(Model):
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None,
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
@ -1684,26 +1659,23 @@ class FlashCausalLM(Model):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=None,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
# TODO not support adapter now, need the add in the future
adapter_data=None,
hpu_attention_meta=batch.hpu_attn_meta,
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
# fix following runtime error in graph replay
# RuntimeError: Neither storage attached to input tensor, not its view
htorch.core.mark_step()
return logits, speculative_logits

@tracer.start_as_current_span("generate_token")
@ -1801,7 +1773,14 @@ class FlashCausalLM(Model):
# instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling:
indices = batch.cu_seqlen_prefill[1:] - 1
batch.position_ids = batch.position_ids[indices]
# pad in left
if batch.prefill_cache_indices is not None:
batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
indices
]
else:
batch.position_ids = batch.position_ids[indices]

batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
indices
|
@ -462,11 +462,9 @@ class FlashVlmCausalLM(FlashCausalLM):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
|
@ -288,11 +288,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
# TODO list