add falcon, santacoder and neox support

OlivierDehaene 2023-06-30 13:19:44 +02:00
parent ddfc02f2a4
commit 16f796f735
13 changed files with 382 additions and 476 deletions

View File

@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \ RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya /opt/conda/bin/conda clean -ya
# Build Flash Attention CUDA kernels # Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder FROM kernel-builder as flash-att-builder
@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
# Build specific version of transformers # Build specific version of transformers
RUN python setup.py build RUN python setup.py build
# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm
# Text Generation Inference base image # Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from transformers builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Install flash-attention dependencies # Install flash-attention dependencies
RUN pip install einops --no-cache-dir RUN pip install einops --no-cache-dir

View File

@ -290,7 +290,6 @@ async fn batching_task(
}; };
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
tracing::info!("{token_budget} {batch_max_tokens}");
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue if let Some((mut new_entries, new_batch, span)) = queue

View File

@ -180,7 +180,11 @@ fn main() -> Result<(), std::io::Error> {
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await .await
.expect("Could not connect to server"); .expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
// Get info from the shard // Get info from the shard
let shard_info = sharded_client let shard_info = sharded_client
.info() .info()

View File

@ -305,7 +305,7 @@ mod tests {
watermark: false, watermark: false,
}, },
stopping_parameters: StoppingCriteriaParameters { stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: true, ignore_eos_token: false,
max_new_tokens: 1, max_new_tokens: 1,
stop_sequences: vec![], stop_sequences: vec![],
}, },

View File

@ -152,7 +152,7 @@ async fn generate(
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
// tracing::debug!("Input: {}", req.0.inputs); tracing::debug!("Input: {}", req.0.inputs);
let compute_characters = req.0.inputs.chars().count(); let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None; let mut add_prompt = None;
@ -286,7 +286,7 @@ async fn generate(
} }
tracing::debug!("Output: {}", output_text); tracing::debug!("Output: {}", output_text);
// tracing::info!("Success"); tracing::info!("Success");
let response = GenerateResponse { let response = GenerateResponse {
generated_text: output_text, generated_text: output_text,

server/Makefile-vllm (new file, 13 lines)
View File

@ -0,0 +1,13 @@
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
vllm:
# Clone vllm
git clone https://github.com/OlivierDehaene/vllm.git
build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)
cd vllm && python setup.py build
install-vllm: build-vllm
pip uninstall vllm -y || true
cd vllm && python setup.py install
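
As a quick smoke test (not part of the commit), you can check that the two extension modules the flash models now import are actually on the Python path after `make install-vllm` or after the Docker COPY above; the module names come from the imports added in this diff, everything else is an assumption.

    # Hedged sanity-check sketch: assumes the vllm kernels were built and installed
    # (e.g. via `make -f server/Makefile-vllm install-vllm`) and a CUDA device is present.
    import torch

    import vllm_cache_ops      # provides reshape_and_cache, used to write K/V into the paged cache
    import vllm_attention_ops  # provides single_query_cached_kv_attention, used on the decode path

    assert torch.cuda.is_available(), "the vllm kernels only run on CUDA devices"
    print("vllm kernel extensions imported:", vllm_cache_ops.__file__, vllm_attention_ops.__file__)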

View File

@ -24,13 +24,15 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from vllm import attention_ops
from vllm import cache_ops
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
import dropout_layer_norm import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -124,6 +126,9 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights, weights=weights,
bias=False, bias=False,
) )
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
@ -145,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin)
cache_ops.reshape_and_cache( vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
) )
@ -178,11 +183,12 @@ class FlashLlamaAttention(torch.nn.Module):
else: else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size] # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3] block_size = kv_cache[1].shape[3]
attention_ops.single_query_cached_kv_attention( vllm_attention_ops.single_query_cached_kv_attention(
attn_output, attn_output,
qkv[:, 0], qkv[:, 0],
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
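
To make the cache bookkeeping above easier to follow, here is a hedged pure-PyTorch sketch of what `vllm_cache_ops.reshape_and_cache` does with `slots`: each token's key/value is scattered into a fixed-size block of the per-layer cache, with slot = block_id * block_size + offset. The [num_blocks, num_heads, head_size, block_size] layout is taken from the comments in this diff; the real kernel stores keys in a more elaborate vectorized layout, so treat this as a reference for the indexing only.

    import torch

    def reshape_and_cache_reference(key, value, key_cache, value_cache, slots):
        # key, value:             [num_tokens, num_kv_heads, head_size]
        # key_cache, value_cache: [num_blocks, num_kv_heads, head_size, block_size] (simplified layout)
        # slots:                  [num_tokens] flat slot ids; conceptually, the slot of the token at
        #                         position p of sequence s is
        #                         block_tables[s, p // block_size] * block_size + p % block_size
        block_size = value_cache.shape[3]
        block_ids = torch.div(slots, block_size, rounding_mode="floor")
        offsets = slots % block_size
        for i in range(key.shape[0]):
            key_cache[block_ids[i], :, :, offsets[i]] = key[i]
            value_cache[block_ids[i], :, :, offsets[i]] = value[i]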

View File

@ -25,11 +25,15 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module):
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True config, prefix=f"{prefix}.dense", weights=weights, bias=True
) )
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill vllm_cache_ops.reshape_and_cache(
if prefill: qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
# Copy to layer past )
layer_past[...] = qkv[:, 1:]
# output # output tensor
attn_output = torch.empty_like(qkv[:, 0]) attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output, attn_output,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv[:, 0] # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
# Add present to the layer_past tensor at the correct indices block_size = kv_cache[1].shape[3]
layer_past[past_present_indices] = qkv[:, 1:] vllm_attention_ops.single_query_cached_kv_attention(
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
attn_output, attn_output,
start_seq_q, qkv[:, 0],
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
if self.use_parallel_residual: if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states) ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states, ln1_hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module):
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values=None, max_s: int,
pre_allocate_past_size: Optional[int] = None, ) -> torch.Tensor:
):
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual, residual,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache[i],
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
past_key_values[:, i],
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual) hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values: Optional[torch.Tensor] = None, max_s: int,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.gpt_neox( hidden_states = self.gpt_neox(
input_ids, input_ids,
position_ids, position_ids,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states) logits = self.embed_out(hidden_states)
return logits, present return logits
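
For readers tracking the signature change, this is a hedged sketch of how a single decode step could call the rewritten FlashGPTNeoXForCausalLM.forward; passing start_seq_prefill=None and end_seq_prefill=None selects the paged-attention decode branch. Only the argument names and types come from the diff; `model`, all shapes, dtypes and BLOCK_SIZE are illustrative assumptions, and the key-cache layout expected by the real kernel is more involved than the placeholder tensors below.

    import torch

    # Illustrative sizes only; a real server derives these from the model config.
    NUM_SEQS, NUM_BLOCKS, BLOCK_SIZE = 2, 128, 16
    num_layers, num_heads, head_size = 44, 64, 96
    device, dtype = torch.device("cuda"), torch.float16

    # One (key_cache, value_cache) pair per layer.
    kv_cache = [
        (
            torch.zeros(NUM_BLOCKS, num_heads, head_size, BLOCK_SIZE, dtype=dtype, device=device),
            torch.zeros(NUM_BLOCKS, num_heads, head_size, BLOCK_SIZE, dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]

    input_ids = torch.tensor([42, 43], device=device)                        # one new token per sequence
    position_ids = torch.tensor([7, 19], device=device)                      # current position of each sequence
    block_tables = torch.randint(0, NUM_BLOCKS, (NUM_SEQS, 8), dtype=torch.int32, device=device)
    slots = torch.tensor([7, 35], device=device)                             # where to write the new K/V
    input_lengths = torch.tensor([8, 20], dtype=torch.int32, device=device)  # cached tokens incl. the new one
    max_s = int(input_lengths.max())

    logits = model(  # model: an already-loaded FlashGPTNeoXForCausalLM (hypothetical here)
        input_ids,
        position_ids,
        None,        # start_seq_prefill -> None means decode
        None,        # end_seq_prefill
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    )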

View File

@ -4,11 +4,15 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
if self.num_heads_kv == 1:
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
else:
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill vllm_cache_ops.reshape_and_cache(
if prefill: kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
# Copy to layer past )
layer_past[...] = kv
# Expand to query shape # output
kv = kv.expand(-1, 2, self.num_heads, self.head_size) attn_output = torch.empty_like(query)
# Prefill
if start_seq_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output, attn_output,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
# Add present to the layer_past tensor at the correct indices # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
layer_past[past_present_indices] = kv block_size = kv_cache[1].shape[3]
# Expand to query shape vllm_attention_ops.single_query_cached_kv_attention(
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output, attn_output,
start_seq_q, query,
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
self.kv_head_mapping = torch.arange(
0, self.num_groups, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_heads)
def forward( def forward(
self, self,
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
)
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if prefill: if start_seq_prefill is not None:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape # Expand to query shape
kv = ( kv = (
kv.unsqueeze(2) kv.unsqueeze(2)
@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module):
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
) )
# output
attn_output = torch.empty_like(query)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1), torch.select(kv, dim=2, index=1),
attn_output, attn_output,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
# Add present to the layer_past tensor at the correct indices # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
layer_past[past_present_indices] = kv block_size = kv_cache[1].shape[3]
# Expand to query shape vllm_attention_ops.single_query_cached_kv_attention(
kv = (
layer_past.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output, attn_output,
start_seq_q, query,
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.dense( return self.dense(
@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
if self.parallel_attn: if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module):
ln_hidden_states, ln_hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
mlp_output = self.mlp(ln_hidden_states) mlp_output = self.mlp(ln_hidden_states)
@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module):
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
ln_attn, residual = self.ln_attn(hidden_states, residual) ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual) ln_mlp, _ = self.ln_mlp(residual)
@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module):
ln_attn, ln_attn,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
# MLP. # MLP.
@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = self.h[0].self_attention.num_heads_kv
2,
self.h[0].self_attention.num_heads_kv,
self.h[0].self_attention.head_size,
)
elif config.model_type == "RefinedWeb": elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = self.h[0].self_attention.num_groups
self.h[0].self_attention.num_groups,
2,
self.h[0].self_attention.head_size,
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"model_type {config.model_type} is not supported." f"model_type {config.model_type} is not supported."
@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values=None, max_s: int,
pre_allocate_past_size: Optional[int] = None, ) -> torch.Tensor:
):
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
*self.cache_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual, residual,
cos, cos,
sin, sin,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache[i],
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashRWForCausalLM(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel):
@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values: Optional[torch.Tensor] = None, max_s: int,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
position_ids, position_ids,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return logits, present return logits
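
The kv_head_mapping tensors added in these attention modules tell the paged decode kernel which KV head each query head should read. The helper below is a hedged recap of the three constructions used in this diff (one KV head per query head, a single shared KV head, and grouped KV heads as in the RefinedWeb variant); the function and its name are ours, only the patterns come from the code above.

    import torch

    def make_kv_head_mapping(num_query_heads: int, num_kv_heads: int, device) -> torch.Tensor:
        if num_kv_heads == num_query_heads:
            # Classic multi-head attention (llama/neox above): query head i reads KV head i.
            return torch.arange(0, num_query_heads, dtype=torch.int32, device=device)
        if num_kv_heads == 1:
            # Multi-query attention (santacoder, falcon with num_heads_kv == 1):
            # every query head reads the single shared KV head 0.
            return torch.zeros(num_query_heads, dtype=torch.int32, device=device)
        # Grouped attention (FlashRWLargeAttention): consecutive groups of query heads share
        # one KV head, matching arange(num_groups).repeat_interleave(heads_per_group).
        assert num_query_heads % num_kv_heads == 0
        heads_per_group = num_query_heads // num_kv_heads
        return torch.arange(0, num_kv_heads, dtype=torch.int32, device=device).repeat_interleave(heads_per_group)

    # e.g. a Falcon-40B-style config (128 query heads, 8 KV groups) gives
    # tensor([0, 0, ..., 0, 1, ..., 7, 7], dtype=torch.int32) with 16 entries per group.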

View File

@ -3,11 +3,15 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row( self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
) )
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.c_attn(hidden_states) qkv = self.c_attn(hidden_states)
@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size)
vllm_cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if prefill: if start_seq_prefill is not None:
# Copy to layer past
layer_past[...] = key_value
# Expand from 1 to num_heads # Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1), torch.select(key_value, dim=1, index=1),
attn_output, attn_output,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
# Add present to the layer_past tensor at the correct indices # kv_cache[1] => [num_blocks, 1, head_size, block_size]
layer_past[past_present_indices] = key_value block_size = kv_cache[1].shape[3]
# Expand from 1 to num_heads vllm_attention_ops.single_query_cached_kv_attention(
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output, attn_output,
start_seq_q, query,
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -361,27 +357,25 @@ class Block(nn.Module):
self, self,
hidden_states, hidden_states,
residual, residual,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn( hidden_states = self.attn(
hidden_states, hidden_states,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
hidden_states, residual = self.ln_2(hidden_states, residual) hidden_states, residual = self.ln_2(hidden_states, residual)
@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values=None, max_s: int,
pre_allocate_past_size: Optional[int] = None, ) -> torch.Tensor:
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1: if self.process_group.size() > 1:
torch.distributed.all_reduce(hidden_states, group=self.process_group) torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
# Decode
else:
prefill = False
residual = None residual = None
for i, layer in enumerate(self.h): for i, layer in enumerate(self.h):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
residual, residual,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache[i],
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashSantacoderForCausalLM(nn.Module): class FlashSantacoderForCausalLM(nn.Module):
@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, start_seq_prefill: Optional[torch.Tensor],
end_seq, end_seq_prefill: Optional[torch.Tensor],
start_seq_q, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
end_seq_q, block_tables: torch.Tensor,
max_s, slots: torch.Tensor,
past_present_indices, input_lengths: torch.Tensor,
past_key_values: Optional[torch.Tensor] = None, max_s: int,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
position_ids, position_ids,
start_seq, start_seq_prefill,
end_seq, end_seq_prefill,
start_seq_q, kv_cache,
end_seq_q, block_tables,
slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return logits, present return logits
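
For completeness, here is a hedged pure-PyTorch reference of what `vllm_attention_ops.single_query_cached_kv_attention` computes on the decode path: for each sequence it gathers the cached keys/values through its block table, routes every query head to its KV head via kv_head_mapping, and runs single-query softmax attention. It is written for clarity only: it assumes the simplified [num_blocks, num_kv_heads, head_size, block_size] layout from the diff comments for both caches, ignores the kernel's real key layout and optional features, and is far too slow for real use.

    import torch

    def single_query_cached_kv_attention_reference(
        out,              # [num_seqs, num_query_heads, head_size], written in place
        query,            # [num_seqs, num_query_heads, head_size]
        key_cache,        # [num_blocks, num_kv_heads, head_size, block_size] (simplified)
        value_cache,      # [num_blocks, num_kv_heads, head_size, block_size]
        kv_head_mapping,  # [num_query_heads] int32: query head -> KV head
        softmax_scale,    # typically 1 / sqrt(head_size)
        block_tables,     # [num_seqs, max_blocks_per_seq] int32
        input_lengths,    # [num_seqs] number of cached tokens per sequence
        block_size,
        max_s,            # kept only for signature parity with the kernel call
    ):
        num_seqs, num_query_heads, _ = query.shape
        for s in range(num_seqs):
            ctx_len = int(input_lengths[s])
            n_blocks = (ctx_len + block_size - 1) // block_size
            blocks = block_tables[s, :n_blocks].long()
            # Gather this sequence's cache: [num_kv_heads, head_size, ctx_len]
            k = torch.cat([key_cache[b] for b in blocks], dim=-1)[..., :ctx_len]
            v = torch.cat([value_cache[b] for b in blocks], dim=-1)[..., :ctx_len]
            for h in range(num_query_heads):
                kvh = int(kv_head_mapping[h])
                scores = (query[s, h] @ k[kvh]) * softmax_scale           # [ctx_len]
                probs = torch.softmax(scores.float(), dim=-1).to(v.dtype)
                out[s, h] = v[kvh] @ probs                                # [head_size]
        return out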

View File

@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM):
model = FlashGPTNeoXForCausalLM(config, weights) model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashNeoXSharded, self).__init__(
model=model.to(device), model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,

View File

@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM):
model = FlashRWForCausalLM(config, weights) model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashRWSharded, self).__init__(
model=model.to(device), model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,

View File

@ -62,10 +62,12 @@ class FlashSantacoderSharded(FlashCausalLM):
model = FlashSantacoderForCausalLM(config, weights) model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashSantacoderSharded, self).__init__(
model=model.to(device), model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,
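
The sharded wrappers above now pass num_layers, num_kv_heads and head_size to the FlashCausalLM base class, which is exactly what is needed to size a paged KV cache (note that SantaCoder passes num_kv_heads=1 and the RefinedWeb models pass their cache_size, i.e. the number of KV heads or groups). The allocator below is only a hedged sketch of that idea, since the FlashCausalLM changes themselves are not shown in this excerpt; block_size, num_blocks and the simplified per-block layout are assumptions.

    import torch

    def allocate_paged_kv_cache(num_layers, num_kv_heads, head_size, num_blocks,
                                block_size=16, dtype=torch.float16, device="cuda"):
        # One (key_cache, value_cache) pair per layer, each
        # [num_blocks, num_kv_heads, head_size, block_size] in this simplified sketch.
        return [
            (
                torch.empty(num_blocks, num_kv_heads, head_size, block_size, dtype=dtype, device=device),
                torch.empty(num_blocks, num_kv_heads, head_size, block_size, dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]

    # Rough memory budget: 2 (K and V) * num_layers * num_blocks * block_size * num_kv_heads
    # * head_size * bytes_per_element. Multi-query models (num_kv_heads=1) therefore cache
    # far fewer bytes per token than full multi-head models.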