add falcon, santacoder and neox support

OlivierDehaene 2023-06-30 13:19:44 +02:00
parent ddfc02f2a4
commit 16f796f735
13 changed files with 382 additions and 476 deletions


@@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya
# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder
@@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from transformers builder
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
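As a quick sanity check, the copied kernel artifacts can be verified from inside the final image. This is a hypothetical smoke test (not part of the commit); the module names follow the imports introduced in the modeling files further down this diff:

```python
# Hypothetical smoke test: confirm the vllm kernels copied into
# site-packages import cleanly inside the final image.
import torch

import vllm_cache_ops       # built by the vllm-builder stage
import vllm_attention_ops   # built by the vllm-builder stage

assert torch.cuda.is_available(), "the vllm kernels require a CUDA device"
print("vllm CUDA kernels loaded")
```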


@@ -290,7 +290,6 @@ async fn batching_task(
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
tracing::debug!("token_budget: {token_budget}, batch_max_tokens: {batch_max_tokens}");
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue


@@ -180,7 +180,11 @@ fn main() -> Result<(), std::io::Error> {
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
// Get info from the shard
let shard_info = sharded_client
.info()


@@ -305,7 +305,7 @@ mod tests {
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: true,
ignore_eos_token: false,
max_new_tokens: 1,
stop_sequences: vec![],
},


@@ -152,7 +152,7 @@ async fn generate(
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
// tracing::debug!("Input: {}", req.0.inputs);
tracing::debug!("Input: {}", req.0.inputs);
let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None;
@@ -286,7 +286,7 @@ async fn generate(
}
tracing::debug!("Output: {}", output_text);
// tracing::info!("Success");
tracing::info!("Success");
let response = GenerateResponse {
generated_text: output_text,

server/Makefile-vllm (new file, 13 lines)

@@ -0,0 +1,13 @@
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
vllm:
# Clone vllm
git clone https://github.com/OlivierDehaene/vllm.git
build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)
cd vllm && python setup.py build
install-vllm: build-vllm
pip uninstall vllm -y || true
cd vllm && python setup.py install


@@ -24,13 +24,15 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from vllm import attention_ops
from vllm import cache_ops
# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -124,6 +126,9 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
@@ -145,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
cache_ops.reshape_and_cache(
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
@@ -178,11 +183,12 @@ class FlashLlamaAttention(torch.nn.Module):
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
attention_ops.single_query_cached_kv_attention(
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
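For intuition, here is a minimal eager-mode sketch of what `vllm_attention_ops.single_query_cached_kv_attention` computes during decode, matching the call above. It assumes a simplified layout of `[num_blocks, num_kv_heads, head_size, block_size]` for BOTH caches (the real vllm key cache is packed differently for memory coalescing); it is a reference, not the kernel:

```python
# Eager reference of single_query_cached_kv_attention (decode step).
import math
import torch

def paged_decode_attention(
    query,            # [num_seqs, num_heads, head_size]
    key_cache,        # [num_blocks, num_kv_heads, head_size, block_size]
    value_cache,      # [num_blocks, num_kv_heads, head_size, block_size]
    kv_head_mapping,  # [num_heads] -> index of the kv head serving each query head
    softmax_scale,
    block_tables,     # [num_seqs, max_blocks_per_seq]
    input_lengths,    # [num_seqs]
):
    block_size = value_cache.shape[3]
    attn_output = torch.empty_like(query)
    for i in range(query.shape[0]):
        length = int(input_lengths[i])
        n_blocks = math.ceil(length / block_size)
        blocks = block_tables[i, :n_blocks].long()
        # Gather this sequence's K/V from its blocks: [num_kv_heads, head_size, length]
        k = key_cache[blocks].permute(1, 2, 0, 3).reshape(
            key_cache.shape[1], key_cache.shape[2], -1)[..., :length]
        v = value_cache[blocks].permute(1, 2, 0, 3).reshape(
            value_cache.shape[1], value_cache.shape[2], -1)[..., :length]
        for h in range(query.shape[1]):
            kv_h = int(kv_head_mapping[h])
            scores = (query[i, h] @ k[kv_h]) * softmax_scale      # [length]
            attn_output[i, h] = v[kv_h] @ torch.softmax(scores, dim=-1)
    return attn_output
```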


@@ -25,11 +25,15 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module):
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(qkv[:, 0])
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
@@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
@@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.gpt_neox(
) -> torch.Tensor:
hidden_states = self.gpt_neox(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
return logits, present
return logits
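The reason `layer_past`, `past_present_indices`, and `prefill` disappear from every signature above is that `vllm_cache_ops.reshape_and_cache` now writes each token's key/value into a shared paged cache at a flat slot index, during both prefill and decode. A minimal sketch of its semantics, under the same simplified cache layout assumed in the earlier reference (the real key-cache packing differs):

```python
# Eager sketch of reshape_and_cache semantics (simplified layout
# [num_blocks, num_kv_heads, head_size, block_size] for both caches).
import torch

def reshape_and_cache_ref(key, value, key_cache, value_cache, slots):
    # key/value: [num_tokens, num_kv_heads, head_size]; slots: [num_tokens]
    block_size = value_cache.shape[3]
    slots = slots.long()
    block_idx = torch.div(slots, block_size, rounding_mode="floor")
    block_off = slots % block_size
    # Scatter each token's K/V into its (block, offset) cell of the paged cache
    key_cache[block_idx, :, :, block_off] = key
    value_cache[block_idx, :, :, block_off] = value
```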


@@ -4,11 +4,15 @@ import torch.distributed
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
if self.num_heads_kv == 1:
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
else:
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
@@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
vllm_cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill
if start_seq_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
self.kv_head_mapping = torch.arange(
0, self.num_groups, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_heads)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
)
# output
attn_output = torch.empty_like(query)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
if start_seq_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
@@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module):
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(
@@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
mlp_output = self.mlp(ln_hidden_states)
@@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
@@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module):
ln_attn,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
# MLP.
@@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
2,
self.h[0].self_attention.num_heads_kv,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_heads_kv
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[
@@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
self.h[0].self_attention.num_groups,
2,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_groups
else:
raise NotImplementedError(
f"model_type {config.model_type} is not supported."
@@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
*self.cache_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
@@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashRWForCausalLM(FlashRWPreTrainedModel):
@@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits
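The `kv_head_mapping` tensors built in this file and its neighbors encode how query heads share key/value heads in the paged attention kernel. A standalone illustration (the head counts below are made-up; in the "RefinedWeb" large attention, `num_heads` is the number of query heads per group):

```python
# Illustration of the three kv_head_mapping patterns in this commit.
import torch

num_heads, num_groups, device = 4, 2, "cpu"

# MHA (NeoX, Llama, Falcon with num_heads_kv > 1): one kv head per query head.
mha = torch.arange(0, num_heads, dtype=torch.int32, device=device)
# MQA (Falcon with num_heads_kv == 1, Santacoder): all query heads share kv head 0.
mqa = torch.zeros(num_heads, dtype=torch.int32, device=device)
# GQA ("RefinedWeb" large attention): each group of query heads shares one kv head.
gqa = torch.arange(0, num_groups, dtype=torch.int32, device=device)
gqa = gqa.repeat_interleave(num_heads)

print(mha.tolist())  # [0, 1, 2, 3]
print(mqa.tolist())  # [0, 0, 0, 0]
print(gqa.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]
```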


@@ -3,11 +3,15 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.c_attn(hidden_states)
@@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
vllm_cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = key_value
if start_seq_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
# kv_cache[1] => [num_blocks, 1, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -361,27 +357,25 @@ class Block(nn.Module):
self,
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1:
torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
# Decode
else:
prefill = False
residual = None
for i, layer in enumerate(self.h):
hidden_states, residual = layer(
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashSantacoderForCausalLM(nn.Module):
@@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits
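In the prefill path above, `key_value.expand(-1, 2, self.num_heads, self.head_size)` broadcasts the single multi-query KV head to all query heads for the flash attention kernel without copying, since `expand` returns a stride-0 view. A small demonstration (the toy shapes are assumptions):

```python
# expand() broadcasts the shared MQA kv head to every query head as a
# stride-0 view over the same storage; no memory is copied.
import torch

num_tokens, num_heads, head_size = 3, 8, 64
key_value = torch.randn(num_tokens, 2, 1, head_size)  # one shared kv head
expanded = key_value.expand(-1, 2, num_heads, head_size)

assert expanded.shape == (num_tokens, 2, num_heads, head_size)
assert expanded.stride(2) == 0                       # broadcast dim
assert expanded.data_ptr() == key_value.data_ptr()   # same storage, no copy
```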


@@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM):
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashNeoXSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype,
device=device,
rank=rank,
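The new `num_layers`, `num_kv_heads`, and `head_size` arguments passed to the `FlashCausalLM` base class here (and in `FlashRWSharded` and `FlashSantacoderSharded` below) give it what it needs to size the paged KV cache. The allocation below is an illustration only; the actual cache manager lives outside this diff, and the packed key-cache layout is an assumption carried over from vllm:

```python
# Sketch of sizing a paged KV cache from num_layers / num_kv_heads / head_size.
import torch

def allocate_kv_cache(num_blocks, num_layers, num_kv_heads, head_size,
                      block_size=16, dtype=torch.float16, device="cuda"):
    # vllm packs the key cache in chunks of 16 bytes along head_size.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    return [
        (
            torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        dtype=dtype, device=device),   # key cache
            torch.empty(num_blocks, num_kv_heads, head_size, block_size,
                        dtype=dtype, device=device),   # value cache
        )
        for _ in range(num_layers)
    ]
```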


@@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM):
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashRWSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,


@@ -62,10 +62,12 @@ class FlashSantacoderSharded(FlashCausalLM):
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashSantacoderSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,