From 92a74ea0364ed8785ca17a4f64f347c114746f4b Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 5 Jun 2023 18:54:23 +0200
Subject: [PATCH] revert some changes

---
 .../custom_modeling/flash_llama_modeling.py | 26 +++++++++----------
 .../custom_modeling/flash_neox_modeling.py  | 26 +++++++++----------
 2 files changed, 24 insertions(+), 28 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 0f1a1a54..2ea88e9d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -140,25 +140,22 @@ class FlashLlamaAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
-
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
@@ -176,16 +173,17 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -386,7 +384,7 @@ class FlashLlamaModel(torch.nn.Module):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 55541e45..21362b22 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -125,25 +125,22 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
-
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
@@ -161,16 +158,17 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -395,7 +393,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
            )
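
Note: for reference, the indexing pattern this revert restores can be illustrated with a short, self-contained PyTorch sketch. This is an illustration only, not the TGI source: the rotary and flash-attention kernels are omitted, and the shapes (num_tokens, cache_size, num_heads, head_size) are made-up toy values.

    import torch

    # Toy shapes for illustration; the real values come from the model config.
    num_tokens, cache_size, num_heads, head_size = 4, 16, 2, 8

    # Fused QKV projection output, reshaped as in the patch:
    # (num_tokens, 3, num_heads, head_size), where indices 0/1/2 along dim=1
    # select query/key/value.
    qkv = torch.randn(num_tokens, 3 * num_heads * head_size)
    qkv = qkv.view(-1, 3, num_heads, head_size)

    # qkv[:, 0] and qkv[:, 1] are views into the same storage, which is why
    # an in-place rotary embedding applied to them mutates qkv itself: the
    # later qkv[:, 1:] slice then sees the rotated keys without any copy.
    assert qkv[:, 0].data_ptr() == qkv.data_ptr()

    # Prefill: write the key/value pair qkv[:, 1:], shaped
    # (num_tokens, 2, num_heads, head_size), into the cache at the positions
    # of the current tokens.
    layer_past = torch.zeros(cache_size, 2, num_heads, head_size)
    past_present_indices = torch.arange(num_tokens)
    layer_past[past_present_indices] = qkv[:, 1:]

    # Decode attends over the cache instead: layer_past[:, 0] is every
    # cached key and layer_past[:, 1] every cached value.
    keys, values = layer_past[:, 0], layer_past[:, 1]
    print(keys.shape, values.shape)  # both (cache_size, num_heads, head_size)

Note also that in the decode branch the diff passes layer_past[:, 0] and layer_past[:, 1] (the full cache, including the just-written entry) to the attention kernel, whereas prefill attends over the current qkv slices directly.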