feat(server): use torch.select to decrease cpu bottleneck

This commit is contained in:
OlivierDehaene 2023-06-01 19:16:37 +02:00
parent 95d3546976
commit 4cd0a9f0c8
5 changed files with 56 additions and 48 deletions

View File

@ -134,23 +134,25 @@ class FlashLlamaAttention(torch.nn.Module):
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
q, kv = qkv.split([1, 2], dim=1)
q = q.squeeze(1)
# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
self.rotary_emb(q, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
layer_past[...] = kv
# output
attn_output = torch.empty_like(qkv[:, 0])
attn_output = torch.empty_like(q)
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
q,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
@ -166,17 +168,16 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:]
layer_past[layer_past_present_indices] = kv
# output
attn_output = torch.empty_like(query)
attn_output = torch.empty_like(q)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
q,
torch.select(layer_past, dim=1, index=0),
torch.select(layer_past, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
@ -237,7 +238,10 @@ class LlamaMLP(nn.Module):
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(torch.select(gate_up_states, dim=1, index=0))
* torch.select(gate_up_states, dim=1, index=1)
)
class FlashLlamaLayer(nn.Module):

View File

@ -101,23 +101,25 @@ class FlashNeoxAttention(torch.nn.Module):
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
q, kv = qkv.split([1, 2], dim=1)
q = q.squeeze(1)
# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
self.rotary_emb(q, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
layer_past[...] = kv
# output
attn_output = torch.empty_like(qkv[:, 0])
attn_output = torch.empty_like(q)
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
q,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
@ -133,17 +135,16 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:]
layer_past[layer_past_present_indices] = kv
# output
attn_output = torch.empty_like(query)
attn_output = torch.empty_like(q)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
q,
torch.select(layer_past, dim=1, index=0),
torch.select(layer_past, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,

View File

@ -149,7 +149,7 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
if layer_past_present_indices is None:
@ -163,8 +163,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
@ -190,8 +190,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
@ -288,7 +288,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, :, 0], cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
# Prefill
if layer_past_present_indices is None:
@ -306,8 +306,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
@ -337,8 +337,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,

View File

@ -74,8 +74,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
@ -101,8 +101,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,

View File

@ -112,7 +112,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(tokenized_input)
# Position ids
position_ids.append(np.arange(0, input_length))
position_ids.append(np.arange(0, input_length, dtype=np.int32))
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
@ -141,16 +141,19 @@ class FlashCausalLMBatch(Batch):
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = np.concatenate(position_ids, dtype=np.int32)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
return cls(