Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)
missing gptj change...

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 787dbe98a8
commit 376e0507b7
@@ -156,10 +156,8 @@ class FlashGPTJAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -177,16 +175,9 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             self.rotary_emb(query, key, cos, sin)
 
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
-
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -201,7 +192,6 @@ class FlashGPTJAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -211,7 +201,6 @@ class FlashGPTJAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -272,10 +261,8 @@ class FlashGPTJLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -286,10 +273,8 @@ class FlashGPTJLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
 
@@ -334,10 +319,8 @@ class FlashGPTJModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
@@ -355,10 +338,8 @@ class FlashGPTJModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
 
@@ -387,10 +368,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -400,10 +379,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices=prefill_cache_indices,
             hpu_attention_meta=hpu_attention_meta,
         )
         if lm_head_indices is not None:
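For context, here is a minimal sketch of what the patch changes in the KV-cache write path, using plain tensors as a stand-in for TGI's KVCache object. The function names store_before/store_after and the flat slot-indexed cache_k/cache_v tensors are hypothetical illustrations, not the backend's actual API; the real write goes through kv_cache.store() with kv_scales as shown in the hunks above.

    import torch

    def store_before(key, value, cache_k, cache_v, slots, prefill_cache_indices):
        # Old path: optionally gather only the tokens selected by
        # prefill_cache_indices before writing them into the cache.
        if prefill_cache_indices is not None:
            key_to_cache = key[prefill_cache_indices]
            value_to_cache = value[prefill_cache_indices]
        else:
            key_to_cache = key
            value_to_cache = value
        cache_k[slots] = key_to_cache
        cache_v[slots] = value_to_cache

    def store_after(key, value, cache_k, cache_v, slots):
        # New path: every token's key/value is written straight to its slot;
        # the gather step is gone.
        cache_k[slots] = key
        cache_v[slots] = value

The same two arguments, block_tables and prefill_cache_indices, are dropped from every forward signature in the file, presumably because the HPU paged-attention path takes its block layout from hpu_attention_meta instead.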