feat(server): use torch.select to decrease cpu bottleneck

This commit is contained in:
OlivierDehaene 2023-06-01 19:16:37 +02:00
parent 95d3546976
commit 4cd0a9f0c8
5 changed files with 56 additions and 48 deletions

View File

@ -134,23 +134,25 @@ class FlashLlamaAttention(torch.nn.Module):
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
q, kv = qkv.split([1, 2], dim=1)
q = q.squeeze(1)
# Inplace rotary # Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(q, cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
# Prefill # Prefill
if layer_past_present_indices is None: if layer_past_present_indices is None:
# Copy to layer past # Copy to layer past
layer_past[...] = qkv[:, 1:] layer_past[...] = kv
# output # output
attn_output = torch.empty_like(qkv[:, 0]) attn_output = torch.empty_like(q)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv[:, 0], q,
qkv[:, 1], torch.select(kv, dim=1, index=0),
qkv[:, 2], torch.select(kv, dim=1, index=1),
attn_output, attn_output,
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
@ -166,17 +168,16 @@ class FlashLlamaAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices # Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:] layer_past[layer_past_present_indices] = kv
# output # output
attn_output = torch.empty_like(query) attn_output = torch.empty_like(q)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, q,
layer_past[:, 0], torch.select(layer_past, dim=1, index=0),
layer_past[:, 1], torch.select(layer_past, dim=1, index=1),
attn_output, attn_output,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens, cu_seqlens,
@ -237,7 +238,10 @@ class LlamaMLP(nn.Module):
def forward(self, hidden_states): def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) return self.down_proj(
self.act(torch.select(gate_up_states, dim=1, index=0))
* torch.select(gate_up_states, dim=1, index=1)
)
class FlashLlamaLayer(nn.Module): class FlashLlamaLayer(nn.Module):

View File

@ -101,23 +101,25 @@ class FlashNeoxAttention(torch.nn.Module):
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
q, kv = qkv.split([1, 2], dim=1)
q = q.squeeze(1)
# Inplace rotary # Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(q, cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
# Prefill # Prefill
if layer_past_present_indices is None: if layer_past_present_indices is None:
# Copy to layer past # Copy to layer past
layer_past[...] = qkv[:, 1:] layer_past[...] = kv
# output # output
attn_output = torch.empty_like(qkv[:, 0]) attn_output = torch.empty_like(q)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv[:, 0], q,
qkv[:, 1], torch.select(kv, dim=1, index=0),
qkv[:, 2], torch.select(kv, dim=1, index=1),
attn_output, attn_output,
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
@ -133,17 +135,16 @@ class FlashNeoxAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices # Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:] layer_past[layer_past_present_indices] = kv
# output # output
attn_output = torch.empty_like(query) attn_output = torch.empty_like(q)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, q,
layer_past[:, 0], torch.select(layer_past, dim=1, index=0),
layer_past[:, 1], torch.select(layer_past, dim=1, index=1),
attn_output, attn_output,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens, cu_seqlens,

View File

@ -149,7 +149,7 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill # Prefill
if layer_past_present_indices is None: if layer_past_present_indices is None:
@ -163,8 +163,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
kv[:, 0], torch.select(kv, dim=1, index=0),
kv[:, 1], torch.select(kv, dim=1, index=1),
attn_output, attn_output,
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
@ -190,8 +190,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
kv[:, 0], torch.select(kv, dim=1, index=0),
kv[:, 1], torch.select(kv, dim=1, index=1),
attn_output, attn_output,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens, cu_seqlens,
@ -288,7 +288,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, :, 0], cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
# Prefill # Prefill
if layer_past_present_indices is None: if layer_past_present_indices is None:
@ -306,8 +306,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
kv[:, :, 0], torch.select(kv, dim=2, index=0),
kv[:, :, 1], torch.select(kv, dim=2, index=1),
attn_output, attn_output,
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
@ -337,8 +337,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
kv[:, :, 0], torch.select(kv, dim=2, index=0),
kv[:, :, 1], torch.select(kv, dim=2, index=1),
attn_output, attn_output,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens, cu_seqlens,

View File

@ -74,8 +74,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
key_value[:, 0], torch.select(key_value, dim=1, index=0),
key_value[:, 1], torch.select(key_value, dim=1, index=1),
attn_output, attn_output,
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
@ -101,8 +101,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
key_value[:, 0], torch.select(key_value, dim=1, index=0),
key_value[:, 1], torch.select(key_value, dim=1, index=1),
attn_output, attn_output,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens, cu_seqlens,

View File

@ -112,7 +112,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(tokenized_input) all_input_ids.append(tokenized_input)
# Position ids # Position ids
position_ids.append(np.arange(0, input_length)) position_ids.append(np.arange(0, input_length, dtype=np.int32))
# Add cumulative lengths of all previous inputs # Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length) cu_seqlens.append(cumulative_length + input_length)
@ -141,16 +141,19 @@ class FlashCausalLMBatch(Batch):
for i, input_ids in enumerate(all_input_ids): for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids all_input_ids_tensor[i, : len(input_ids)] = input_ids
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = np.concatenate(position_ids, dtype=np.int32)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
# Create tensors on device # Create tensors on device
input_ids = torch.tensor( input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
all_input_ids_tensor = torch.tensor( all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device all_input_ids_tensor, dtype=torch.int64, device=device
) )
position_ids = torch.tensor( position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
np.concatenate(position_ids), dtype=torch.int32, device=device
)
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
return cls( return cls(