missing gptj change...

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-28 01:08:40 -07:00
parent 787dbe98a8
commit 376e0507b7
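
The hunks below drop the `block_tables` and `prefill_cache_indices` arguments from the GPT-J attention, layer, model, and causal-LM forward paths and store the key/value tensors directly. As a rough sketch of the resulting call shape (the surrounding variable names and the unpacked return value are assumptions for illustration, not part of this commit), a caller would now invoke the model without those two arguments:

    # Sketch of a call under the new FlashGPTJForCausalLM.forward signature;
    # variable names and the returned tuple are assumed, not from this diff.
    logits, speculative_logits = model(
        input_ids,
        position_ids,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
        lm_head_indices=lm_head_indices,
    )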


@@ -156,10 +156,8 @@ class FlashGPTJAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -177,16 +175,9 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             self.rotary_emb(query, key, cos, sin)
 
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
-
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -201,7 +192,6 @@ class FlashGPTJAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -211,7 +201,6 @@ class FlashGPTJAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -272,10 +261,8 @@ class FlashGPTJLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -286,10 +273,8 @@ class FlashGPTJLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
@@ -334,10 +319,8 @@ class FlashGPTJModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
@@ -355,10 +338,8 @@ class FlashGPTJModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
@@ -387,10 +368,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -400,10 +379,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices=prefill_cache_indices,
             hpu_attention_meta=hpu_attention_meta,
         )
         if lm_head_indices is not None: