diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 2dcb6ed8..152af74d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -134,23 +134,25 @@ class FlashLlamaAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        q, kv = qkv.split([1, 2], dim=1)
+        q = q.squeeze(1)
 
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(q, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[...] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                q,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -166,17 +168,16 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[layer_past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                q,
+                torch.select(layer_past, dim=1, index=0),
+                torch.select(layer_past, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -237,7 +238,10 @@ class LlamaMLP(nn.Module):
     def forward(self, hidden_states):
         gate_up_states = self.gate_up_proj(hidden_states)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(torch.select(gate_up_states, dim=1, index=0))
+            * torch.select(gate_up_states, dim=1, index=1)
+        )
 
 
 class FlashLlamaLayer(nn.Module):
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 26e21753..17b5012c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -101,23 +101,25 @@ class FlashNeoxAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        q, kv = qkv.split([1, 2], dim=1)
+        q = q.squeeze(1)
 
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(q, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[...] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                q,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -133,17 +135,16 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[layer_past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                q,
+                torch.select(layer_past, dim=1, index=0),
+                torch.select(layer_past, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 545da26a..d169641b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -149,7 +149,7 @@ class FlashRWAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
@@ -163,8 +163,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -190,8 +190,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -288,7 +288,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
@@ -306,8 +306,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -337,8 +337,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 9bded805..13ddcb3f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -74,8 +74,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -101,8 +101,8 @@ class FlashMQAttention(torch.nn.Module):
            # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 35cbe174..b909a2fb 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -112,7 +112,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)
 
             # Position ids
-            position_ids.append(np.arange(0, input_length))
+            position_ids.append(np.arange(0, input_length, dtype=np.int32))
 
             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -141,16 +141,19 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
+        if len(pb.requests) > 1:
+            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            position_ids = np.concatenate(position_ids, dtype=np.int32)
+        else:
+            input_ids = all_input_ids[0]
+            position_ids = position_ids[0]
+
         # Create tensors on device
-        input_ids = torch.tensor(
-            np.concatenate(all_input_ids), dtype=torch.int64, device=device
-        )
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
-        position_ids = torch.tensor(
-            np.concatenate(position_ids), dtype=torch.int32, device=device
-        )
+        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
 
         return cls(
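
A note on the pattern used throughout the attention changes: `Tensor.split` and `torch.select` both return views over the original storage, so the rotary embedding is still applied in place on the underlying `qkv` buffer, while the hot path no longer pays the Python `__getitem__` overhead of re-slicing `qkv[:, 0]` and friends on every layer. In `kv`, index 0 is the key plane and index 1 is the value plane, which is why rotary is applied to `torch.select(kv, dim=1, index=0)`. A minimal sketch of the view semantics being relied on (shapes are illustrative, not the model's real sizes):

```python
import torch

# Illustrative shapes only; the real buffer comes from query_key_value(hidden_states).
num_tokens, num_heads, head_size = 5, 4, 8
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

# split returns views sharing qkv's storage:
#   q  -> [num_tokens, 1, num_heads, head_size]
#   kv -> [num_tokens, 2, num_heads, head_size]
q, kv = qkv.split([1, 2], dim=1)
q = q.squeeze(1)  # squeeze is also a view, so q still aliases qkv

# torch.select(t, dim=1, index=i) is equivalent to t[:, i] and returns a view
assert torch.equal(torch.select(kv, dim=1, index=0), kv[:, 0])

# Because these are views, an in-place op (like the rotary kernel) on the
# selected key plane is visible through the original qkv buffer.
torch.select(kv, dim=1, index=0).zero_()
assert torch.all(qkv[:, 1] == 0)
```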
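On the `flash_causal_lm.py` change: with a single request there is nothing to concatenate, so the per-request NumPy arrays are handed to `torch.tensor` directly, and the dtypes are pinned when the arrays are built (`int32` positions, `int64` ids) so both paths produce identical tensors. A small self-contained sketch of the two paths (the request data is made up for illustration):

```python
import numpy as np
import torch

# Hypothetical batch: two requests of lengths 3 and 2.
all_input_ids = [np.array([5, 6, 7], dtype=np.int64), np.array([8, 9], dtype=np.int64)]
position_ids = [np.arange(0, len(ids), dtype=np.int32) for ids in all_input_ids]

if len(all_input_ids) > 1:
    # Multi-request path: one concatenate per batch, dtype fixed up front.
    input_ids = np.concatenate(all_input_ids, dtype=np.int64)
    position_ids = np.concatenate(position_ids, dtype=np.int32)
else:
    # Single-request fast path: skip the concatenate entirely.
    input_ids = all_input_ids[0]
    position_ids = position_ids[0]

input_ids = torch.tensor(input_ids, dtype=torch.int64)
position_ids = torch.tensor(position_ids, dtype=torch.int32)
assert position_ids.tolist() == [0, 1, 2, 0, 1]
```

Note that the `dtype=` argument to `np.concatenate` requires NumPy 1.20 or newer.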