mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
feat(server): use torch.select to decrease cpu bottleneck
This commit is contained in:
parent
95d3546976
commit
4cd0a9f0c8
@ -134,23 +134,25 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
q, kv = qkv.split([1, 2], dim=1)
|
||||
q = q.squeeze(1)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||
self.rotary_emb(q, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = qkv[:, 1:]
|
||||
layer_past[...] = kv
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
attn_output = torch.empty_like(q)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
q,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
@ -166,17 +168,16 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
query = qkv[:, 0]
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = qkv[:, 1:]
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
attn_output = torch.empty_like(q)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
layer_past[:, 0],
|
||||
layer_past[:, 1],
|
||||
q,
|
||||
torch.select(layer_past, dim=1, index=0),
|
||||
torch.select(layer_past, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
@ -237,7 +238,10 @@ class LlamaMLP(nn.Module):
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
return self.down_proj(
|
||||
self.act(torch.select(gate_up_states, dim=1, index=0))
|
||||
* torch.select(gate_up_states, dim=1, index=1)
|
||||
)
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
|
@ -101,23 +101,25 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
q, kv = qkv.split([1, 2], dim=1)
|
||||
q = q.squeeze(1)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||
self.rotary_emb(q, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = qkv[:, 1:]
|
||||
layer_past[...] = kv
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
attn_output = torch.empty_like(q)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
q,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
@ -133,17 +135,16 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
query = qkv[:, 0]
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = qkv[:, 1:]
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
attn_output = torch.empty_like(q)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
layer_past[:, 0],
|
||||
layer_past[:, 1],
|
||||
q,
|
||||
torch.select(layer_past, dim=1, index=0),
|
||||
torch.select(layer_past, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
|
@ -149,7 +149,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(kv[:, 0], cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
@ -163,8 +163,8 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
@ -190,8 +190,8 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
@ -288,7 +288,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(kv[:, :, 0], cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
@ -306,8 +306,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
torch.select(kv, dim=2, index=0),
|
||||
torch.select(kv, dim=2, index=1),
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
@ -337,8 +337,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
torch.select(kv, dim=2, index=0),
|
||||
torch.select(kv, dim=2, index=1),
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
|
@ -74,8 +74,8 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
key_value[:, 0],
|
||||
key_value[:, 1],
|
||||
torch.select(key_value, dim=1, index=0),
|
||||
torch.select(key_value, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
@ -101,8 +101,8 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
key_value[:, 0],
|
||||
key_value[:, 1],
|
||||
torch.select(key_value, dim=1, index=0),
|
||||
torch.select(key_value, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
|
@ -112,7 +112,7 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids.append(tokenized_input)
|
||||
|
||||
# Position ids
|
||||
position_ids.append(np.arange(0, input_length))
|
||||
position_ids.append(np.arange(0, input_length, dtype=np.int32))
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlens.append(cumulative_length + input_length)
|
||||
@ -141,16 +141,19 @@ class FlashCausalLMBatch(Batch):
|
||||
for i, input_ids in enumerate(all_input_ids):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
if len(pb.requests) > 1:
|
||||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
position_ids = np.concatenate(position_ids, dtype=np.int32)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
position_ids = position_ids[0]
|
||||
|
||||
# Create tensors on device
|
||||
input_ids = torch.tensor(
|
||||
np.concatenate(all_input_ids), dtype=torch.int64, device=device
|
||||
)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
position_ids = torch.tensor(
|
||||
np.concatenate(position_ids), dtype=torch.int32, device=device
|
||||
)
|
||||
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
|
||||
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
|
||||
|
||||
return cls(
|
||||
|
Loading…
Reference in New Issue
Block a user