Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
feat(server): use torch.select to decrease cpu bottleneck
This commit is contained in:
parent 95d3546976
commit 4cd0a9f0c8
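The whole change follows one pattern: everywhere a hot path indexed a packed tensor with Python slicing syntax such as qkv[:, 0], the diff substitutes torch.select(qkv, dim=1, index=0). Both forms return the same zero-copy view; the idea is that torch.select dispatches a single select op instead of going through the heavier __getitem__ indexing machinery on every forward pass, which matters when the GPU kernels are fast enough for the CPU to become the bottleneck. Below is a minimal sketch of the equivalence, with invented shapes and a crude timeit comparison rather than anything from the repository's benchmarks:

import timeit

import torch

# Packed QKV tensor with made-up sizes: (tokens, q/k/v, heads, head_size).
qkv = torch.randn(4096, 3, 32, 128)

# Both expressions produce the same view of the same storage, no copy either way.
view_a = qkv[:, 0]
view_b = torch.select(qkv, dim=1, index=0)
assert view_a.data_ptr() == view_b.data_ptr()
assert torch.equal(view_a, view_b)

# The difference is pure CPU-side overhead per call, not data movement.
print("indexing:", timeit.timeit(lambda: qkv[:, 0], number=10_000))
print("select:  ", timeit.timeit(lambda: torch.select(qkv, dim=1, index=0), number=10_000))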
@@ -134,23 +134,25 @@ class FlashLlamaAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        q, kv = qkv.split([1, 2], dim=1)
+        q = q.squeeze(1)
 
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(q, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[...] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                q,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
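Beyond swapping the indexing calls, the Llama and NeoX attention paths now split the fused qkv projection once into a squeezed query and a packed kv block, so the key/value half can be written to the cache in one assignment and re-selected as views afterwards. A shape-only sketch of that split, with illustrative sizes rather than real model dimensions:

import torch

num_tokens, num_heads, head_size = 8, 32, 128
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

q, kv = qkv.split([1, 2], dim=1)      # views: (8, 1, 32, 128) and (8, 2, 32, 128)
q = q.squeeze(1)                      # (8, 32, 128), matching what qkv[:, 0] looked like before

k = torch.select(kv, dim=1, index=0)  # key plane,   (8, 32, 128)
v = torch.select(kv, dim=1, index=1)  # value plane, (8, 32, 128)
assert q.shape == k.shape == v.shape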
@@ -166,17 +168,16 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[layer_past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                q,
+                torch.select(layer_past, dim=1, index=0),
+                torch.select(layer_past, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
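In the decode branch the same kv block is scattered into the preallocated cache at the slots for the current tokens, and the key/value planes of the cache are then handed to flash attention as views. A toy version with invented cache and batch sizes:

import torch

cache_slots, num_heads, head_size = 64, 8, 32
layer_past = torch.zeros(cache_slots, 2, num_heads, head_size)

# Three sequences each decode one new token; the indices say where they live in the cache.
kv = torch.randn(3, 2, num_heads, head_size)
layer_past_present_indices = torch.tensor([5, 17, 42])

layer_past[layer_past_present_indices] = kv
key_cache = torch.select(layer_past, dim=1, index=0)    # (64, 8, 32) view
value_cache = torch.select(layer_past, dim=1, index=1)  # (64, 8, 32) view
assert torch.equal(key_cache[5], kv[0, 0])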
@@ -237,7 +238,10 @@ class LlamaMLP(nn.Module):
     def forward(self, hidden_states):
         gate_up_states = self.gate_up_proj(hidden_states)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(torch.select(gate_up_states, dim=1, index=0))
+            * torch.select(gate_up_states, dim=1, index=1)
+        )
 
 
 class FlashLlamaLayer(nn.Module):
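The MLP change applies the same trick to the fused gate/up projection: the output is viewed as (-1, 2, intermediate_size) and torch.select picks the gate and up halves for the gated activation. A toy version of the rewritten forward, with made-up sizes and nn.SiLU standing in for whatever activation the model config selects:

import torch
from torch import nn

hidden_size, intermediate_size = 16, 64
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act = nn.SiLU()

hidden_states = torch.randn(4, hidden_size)
gate_up_states = gate_up_proj(hidden_states).view(-1, 2, intermediate_size)
out = down_proj(
    act(torch.select(gate_up_states, dim=1, index=0))
    * torch.select(gate_up_states, dim=1, index=1)
)
print(out.shape)  # torch.Size([4, 16])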
@@ -101,23 +101,25 @@ class FlashNeoxAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        q, kv = qkv.split([1, 2], dim=1)
+        q = q.squeeze(1)
 
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(q, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=1), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+            layer_past[...] = kv
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                q,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -133,17 +135,16 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[layer_past_present_indices] = kv
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(q)
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+                q,
+                torch.select(layer_past, dim=1, index=0),
+                torch.select(layer_past, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -149,7 +149,7 @@ class FlashRWAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
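One detail worth noting for the rotary lines: torch.select returns a view, so applying the in-place rotary kernel to the selected key plane still mutates the packed kv tensor, exactly as the old kv[:, 0] indexing did. A tiny demonstration, with a multiply standing in for the rotary op and invented sizes:

import torch

kv = torch.randn(8, 2, 4, 16)               # (tokens, k/v, heads, head_size)
key_view = torch.select(kv, dim=1, index=0)

key_view.mul_(2.0)                           # stand-in for the in-place rotary kernel
assert torch.equal(key_view, kv[:, 0])       # the packed tensor saw the update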
@@ -163,8 +163,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -190,8 +190,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -288,7 +288,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
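FlashRWLargeAttention keeps an extra leading dimension in front of the key/value axis (the old indexing was kv[:, :, 0]), so the selects here use dim=2 instead of dim=1. A shape sketch with invented group and head counts:

import torch

num_tokens, num_groups, num_heads, head_size = 8, 4, 2, 64
kv = torch.randn(num_tokens, num_groups, 2, num_heads, head_size)

key = torch.select(kv, dim=2, index=0)    # (8, 4, 2, 64) view, same as kv[:, :, 0]
value = torch.select(kv, dim=2, index=1)  # (8, 4, 2, 64) view, same as kv[:, :, 1]
assert torch.equal(key, kv[:, :, 0])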
@@ -306,8 +306,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -337,8 +337,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -74,8 +74,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -101,8 +101,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlens_q,
                 cu_seqlens,
@@ -112,7 +112,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)
 
             # Position ids
-            position_ids.append(np.arange(0, input_length))
+            position_ids.append(np.arange(0, input_length, dtype=np.int32))
 
             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
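The batching code gets a related CPU-side tweak: position ids are built as int32 numpy arrays from the start, so the later torch.tensor(..., dtype=torch.int32) call no longer has to downcast from numpy's platform default (typically int64). A minimal sketch with invented lengths:

import numpy as np
import torch

input_lengths = [3, 5]
position_ids = [np.arange(0, n, dtype=np.int32) for n in input_lengths]

position_ids = np.concatenate(position_ids, dtype=np.int32)
position_ids = torch.tensor(position_ids, dtype=torch.int32)
print(position_ids.dtype, position_ids.shape)  # torch.int32 torch.Size([8])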
@@ -141,16 +141,19 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
+        if len(pb.requests) > 1:
+            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            position_ids = np.concatenate(position_ids, dtype=np.int32)
+        else:
+            input_ids = all_input_ids[0]
+            position_ids = position_ids[0]
+
         # Create tensors on device
-        input_ids = torch.tensor(
-            np.concatenate(all_input_ids), dtype=torch.int64, device=device
-        )
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
-        position_ids = torch.tensor(
-            np.concatenate(position_ids), dtype=torch.int32, device=device
-        )
+        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
 
         return cls(
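The last hunk adds a fast path for single-request batches: np.concatenate is only paid when there is more than one request, otherwise the lone array is handed to torch.tensor directly. A simplified sketch of that branch, with a fabricated tokenized input and the request-count check reduced to the local lists:

import numpy as np
import torch

all_input_ids = [[101, 2009, 2003]]                 # one tokenized request (made-up ids)
position_ids = [np.arange(0, 3, dtype=np.int32)]

if len(all_input_ids) > 1:
    input_ids = np.concatenate(all_input_ids, dtype=np.int64)
    position_ids = np.concatenate(position_ids, dtype=np.int32)
else:
    input_ids = all_input_ids[0]
    position_ids = position_ids[0]

input_ids = torch.tensor(input_ids, dtype=torch.int64)
position_ids = torch.tensor(position_ids, dtype=torch.int32)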