diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
index 3135acde..41eeab78 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -156,10 +156,8 @@ class FlashGPTJAttention(torch.nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -177,16 +175,9 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             self.rotary_emb(query, key, cos, sin)
 
-        if prefill_cache_indices is not None:
-            key_to_cache = key[prefill_cache_indices]
-            value_to_cache = value[prefill_cache_indices]
-        else:
-            key_to_cache = key
-            value_to_cache = value
-
         kv_cache.store(
-            key=key_to_cache,
-            value=value_to_cache,
+            key=key,
+            value=value,
             slots=slots,
             kv_scales=self.kv_scales,
         )
@@ -201,7 +192,6 @@ class FlashGPTJAttention(torch.nn.Module):
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
-                block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
             )
         # Decode
@@ -211,7 +201,6 @@ class FlashGPTJAttention(torch.nn.Module):
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
@@ -272,10 +261,8 @@ class FlashGPTJLayer(nn.Module):
         sin,
         cu_seqlen_prefill,
         kv_cache,
-        block_tables,
         slots,
         seqlen,
-        prefill_cache_indices,
         hpu_attention_meta,
     ):
         hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -286,10 +273,8 @@ class FlashGPTJLayer(nn.Module):
             sin,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices,
             hpu_attention_meta,
         )
 
@@ -334,10 +319,8 @@ class FlashGPTJModel(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
@@ -355,10 +338,8 @@ class FlashGPTJModel(torch.nn.Module):
                 sin,
                 cu_seqlen_prefill,
                 kv_cache[i],
-                block_tables,
                 slots,
                 seqlen,
-                prefill_cache_indices,
                 hpu_attention_meta,
             )
 
@@ -387,10 +368,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        prefill_cache_indices: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
@@ -400,10 +379,8 @@ class FlashGPTJForCausalLM(torch.nn.Module):
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
-            block_tables,
             slots,
             seqlen,
-            prefill_cache_indices=prefill_cache_indices,
             hpu_attention_meta=hpu_attention_meta,
         )
         if lm_head_indices is not None: