diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index eb98a756..7c7096f7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -180,16 +180,11 @@ class FlashPhiAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        # Apply partial rotary embedding and store the end of the embedding
-        query_pass = query[:, :, self.rotary_emb_dim:]
-        key_pass = torch.select(kv, dim=1, index=0)[:, :, self.rotary_emb_dim:]
-
-        # Apply in place positional rotary embeddings
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-
-        # Restore the query and key from the partial rotary embedding
-        kv[:, 0, :, self.rotary_emb_dim:] = key_pass
-        query[:, :, self.rotary_emb_dim:] = query_pass
+        # Apply partial positional embeddings in place
+        self.rotary_emb(
+            query[:, :, :self.rotary_emb_dim], kv[:, 0, :, :self.rotary_emb_dim],
+            cos, sin
+        )
 
         # Reshape key and value and cache
         paged_attention.reshape_and_cache(
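
Note on the change: basic indexing in PyTorch returns views, not copies, so the removed `query_pass` and `key_pass` aliased the same memory the rotary kernel could overwrite, and the "restore" assignments copied a region onto itself. The replacement leans on the same view semantics deliberately: passing `query[:, :, :self.rotary_emb_dim]` hands the kernel a view, so rotating it in place updates the full tensor while the remaining pass-through channels are simply never touched. A minimal sketch of that behavior, assuming standard PyTorch semantics; `rotate_half_inplace` is a hypothetical stand-in for the real `self.rotary_emb` kernel and the shapes are illustrative:

import torch

def rotate_half_inplace(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> None:
    # Hypothetical stand-in for self.rotary_emb: rotates `x` in place.
    # `x` may be a view; the writes propagate to the base tensor.
    d = x.shape[-1] // 2
    x1 = x[..., :d].clone()
    x2 = x[..., d:].clone()
    x[..., :d] = x1 * cos - x2 * sin
    x[..., d:] = x2 * cos + x1 * sin

tokens, heads, head_size, rotary_dim = 4, 2, 8, 4
query = torch.randn(tokens, heads, head_size)
tail_before = query[:, :, rotary_dim:].clone()  # the pass-through channels

cos = torch.randn(tokens, 1, rotary_dim // 2)
sin = torch.randn(tokens, 1, rotary_dim // 2)

# Rotating a view of the first rotary_dim channels updates `query` itself...
rotate_half_inplace(query[:, :, :rotary_dim], cos, sin)

# ...while the remaining channels stay intact, with no save/restore needed.
assert torch.equal(query[:, :, rotary_dim:], tail_before)

This only holds if the kernel accepts non-contiguous views and writes in place; whether the production kernel does is a property of the TGI rotary implementation, not something this sketch verifies.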