Mirror of https://github.com/huggingface/text-generation-inference.git
add falcon, santacoder and neox support
commit 16f796f735 (parent ddfc02f2a4)

Dockerfile (16 lines changed)
@@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
     /opt/conda/bin/conda clean -ya

-
 # Build Flash Attention CUDA kernels
 FROM kernel-builder as flash-att-builder

@@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build

+# Build vllm CUDA kernels
+FROM kernel-builder as vllm-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-vllm Makefile
+
+# Build specific version of vllm
+RUN make build-vllm
+
 # Text Generation Inference base image
 FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base

@@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

-# Copy build artifacts from transformers builder
+# Copy build artifacts from custom kernels builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels

+# Copy builds artifacts from vllm builder
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir

@@ -290,7 +290,6 @@ async fn batching_task(
        };

        let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-       tracing::info!("{token_budget} {batch_max_tokens}");

        // Try to get a new batch
        if let Some((mut new_entries, new_batch, span)) = queue
@@ -180,7 +180,11 @@ fn main() -> Result<(), std::io::Error> {
    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
        .await
        .expect("Could not connect to server");
+   // Clear the cache; useful if the webserver rebooted
+   sharded_client
+       .clear_cache(None)
+       .await
+       .expect("Unable to clear cache");
    // Get info from the shard
    let shard_info = sharded_client
        .info()
@@ -305,7 +305,7 @@ mod tests {
            watermark: false,
        },
        stopping_parameters: StoppingCriteriaParameters {
-           ignore_eos_token: true,
+           ignore_eos_token: false,
            max_new_tokens: 1,
            stop_sequences: vec![],
        },
@@ -152,7 +152,7 @@ async fn generate(
    let start_time = Instant::now();
    metrics::increment_counter!("tgi_request_count");

-   // tracing::debug!("Input: {}", req.0.inputs);
+   tracing::debug!("Input: {}", req.0.inputs);

    let compute_characters = req.0.inputs.chars().count();
    let mut add_prompt = None;
@@ -286,7 +286,7 @@ async fn generate(
    }

    tracing::debug!("Output: {}", output_text);
-   // tracing::info!("Success");
+   tracing::info!("Success");

    let response = GenerateResponse {
        generated_text: output_text,

server/Makefile-vllm (new file, 13 lines)
@@ -0,0 +1,13 @@
+vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
+
+vllm:
+	# Clone vllm
+	git clone https://github.com/OlivierDehaene/vllm.git
+
+build-vllm: vllm
+	cd vllm && git fetch && git checkout $(vllm_commit)
+	cd vllm && python setup.py build
+
+install-vllm: build-vllm
+	pip uninstall vllm -y || true
+	cd vllm && python setup.py install
@@ -24,13 +24,15 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from vllm import attention_ops
-from vllm import cache_ops

 # Flash attention imports
 import flash_attn_cuda
 import dropout_layer_norm

+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -124,6 +126,9 @@ class FlashLlamaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.kv_head_mapping = torch.arange(
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )

     def forward(
         self,
@@ -145,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)

-        cache_ops.reshape_and_cache(
+        vllm_cache_ops.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
         )

@@ -178,11 +183,12 @@ class FlashLlamaAttention(torch.nn.Module):
         else:
             # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
             block_size = kv_cache[1].shape[3]
-            attention_ops.single_query_cached_kv_attention(
+            vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],
                 kv_cache[1],
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
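
The hunks above (and the analogous ones for NeoX, Falcon/RW and SantaCoder below) replace the dense, pre-allocated past tensor with a block-structured KV cache: reshape_and_cache scatters each new key/value pair into a per-token slot, and single_query_cached_kv_attention gathers them back through block_tables at decode time. As a reading aid only, here is a naive PyTorch sketch of those semantics. It mirrors the argument order used in the diff and assumes the value-cache layout given in the comments ([num_blocks, num_kv_heads, head_size, block_size]) for both caches; it is not the actual CUDA implementation.

    # Naive reference sketch of the paged KV-cache semantics, for illustration only.
    import torch

    def reshape_and_cache_ref(key, value, key_cache, value_cache, slots):
        # key, value: [num_tokens, num_kv_heads, head_size]; slots: flat slot index per token.
        block_size = value_cache.shape[3]
        for i, slot in enumerate(slots.tolist()):
            block, offset = divmod(slot, block_size)
            key_cache[block, :, :, offset] = key[i]
            value_cache[block, :, :, offset] = value[i]

    def single_query_cached_kv_attention_ref(
        out, query, key_cache, value_cache, kv_head_mapping,
        softmax_scale, block_tables, input_lengths,
    ):
        # query, out: [num_seqs, num_heads, head_size]; block_tables: [num_seqs, max_blocks_per_seq].
        block_size = value_cache.shape[3]
        for s in range(query.shape[0]):
            length = int(input_lengths[s])
            positions = torch.arange(length)
            blocks = block_tables[s, positions // block_size].long()
            offsets = positions % block_size
            keys = key_cache[blocks, :, :, offsets]      # [length, num_kv_heads, head_size]
            values = value_cache[blocks, :, :, offsets]  # [length, num_kv_heads, head_size]
            for h in range(query.shape[1]):
                kv_h = int(kv_head_mapping[h])
                scores = torch.softmax(keys[:, kv_h] @ query[s, h] * softmax_scale, dim=0)
                out[s, h] = scores @ values[:, kv_h]

Indexing through block_tables is what lets a sequence grow block by block instead of reserving a contiguous region up front, which is why pre_allocate_past_size and past_present_indices disappear from the model forwards below.
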
@@ -25,11 +25,15 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
-from typing import Optional
+from typing import Optional, List, Tuple

 # Flash attention imports
 import flash_attn_cuda

+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module):
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
+        self.kv_head_mapping = torch.arange(
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )

     def forward(
         self,
         hidden_states,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)

-        # Prefill
-        if prefill:
-            # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
-
-            # output
-            attn_output = torch.empty_like(qkv[:, 0])
+        vllm_cache_ops.reshape_and_cache(
+            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+        )
+
+        # output tensor
+        attn_output = torch.empty_like(qkv[:, 0])
+
+        # Prefill
+        if start_seq_prefill is not None:
             # flash attention
             flash_attn_cuda.fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                start_seq,
-                end_seq,
-                start_seq,
-                end_seq,
+                start_seq_prefill,
+                end_seq_prefill,
+                start_seq_prefill,
+                end_seq_prefill,
                 max_s,
                 max_s,
                 0.0,
@@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv[:, 0]
-            # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = qkv[:, 1:]
-
-            # output
-            attn_output = torch.empty_like(query)
-            # flash attention
-            flash_attn_cuda.fwd(
-                query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
+            block_size = kv_cache[1].shape[3]
+            vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                start_seq_q,
-                end_seq_q,
-                start_seq,
-                end_seq,
-                1,
-                max_s,
-                0.0,
+                qkv[:, 0],
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
                 self.softmax_scale,
-                False,
-                False,
-                False,
-                0,
-                None,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module):
         residual,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module):
                 ln1_hidden_states,
                 cos,
                 sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache,
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                layer_past,
-                past_present_indices,
-                prefill,
             )

             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache,
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                layer_past,
-                past_present_indices,
-                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(
@@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values=None,
-        pre_allocate_past_size: Optional[int] = None,
-    ):
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+    ) -> torch.Tensor:
         hidden_states = self.embed_in(input_ids)

-        # Prefill
-        if past_key_values is None:
-            assert pre_allocate_past_size is not None
-
-            prefill = True
-
-            # Create past tensor
-            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(input_ids),
-                    len(self.layers),
-                    2,
-                    self.num_heads,
-                    self.head_size,
-                )
-            )
-        # Decode
-        else:
-            prefill = False
-
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
@@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 residual,
                 cos,
                 sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                past_key_values[:, i],
-                past_present_indices,
-                prefill,
             )

-        if prefill:
-            present = past_key_values
-            # Create padded past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    pre_allocate_past_size,
-                    len(self.layers),
-                    2,
-                    self.num_heads,
-                    self.head_size,
-                )
-            )
-            # We slice only once instead of at every layer
-            past_key_values[past_present_indices] = present
-
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)

-        return hidden_states, past_key_values
+        return hidden_states


 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
@@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values: Optional[torch.Tensor] = None,
-        pre_allocate_past_size: Optional[int] = None,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ):
-        hidden_states, present = self.gpt_neox(
+    ) -> torch.Tensor:
+        hidden_states = self.gpt_neox(
             input_ids,
             position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            start_seq_prefill,
+            end_seq_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
             max_s,
-            past_present_indices,
-            past_key_values,
-            pre_allocate_past_size,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits = self.embed_out(hidden_states)
-        return logits, present
+        return logits
@@ -4,11 +4,15 @@ import torch.distributed
 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional
+from typing import Optional, List, Tuple

 # Flash attention imports
 import flash_attn_cuda

+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )

+        if self.num_heads_kv == 1:
+            self.kv_head_mapping = torch.zeros(
+                self.num_heads, dtype=torch.int32, device=weights.device
+            )
+        else:
+            self.kv_head_mapping = torch.arange(
+                0, self.num_heads, dtype=torch.int32, device=weights.device
+            )
+
     def forward(
         self,
         hidden_states,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         qkv = self.query_key_value(hidden_states)

@@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

-        # Prefill
-        if prefill:
-            # Copy to layer past
-            layer_past[...] = kv
-            # Expand to query shape
-            kv = kv.expand(-1, 2, self.num_heads, self.head_size)
+        vllm_cache_ops.reshape_and_cache(
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        )

-            # output
-            attn_output = torch.empty_like(query)
+        # output
+        attn_output = torch.empty_like(query)
+
+        # Prefill
+        if start_seq_prefill is not None:
+            if self.num_heads_kv == 1:
+                # Expand to query shape
+                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
+
             # flash attention
             flash_attn_cuda.fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
                 attn_output,
-                start_seq,
-                end_seq,
-                start_seq,
-                end_seq,
+                start_seq_prefill,
+                end_seq_prefill,
+                start_seq_prefill,
+                end_seq_prefill,
                 max_s,
                 max_s,
                 0.0,
@@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
-            # Expand to query shape
-            kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
-
-            # output
-            attn_output = torch.empty_like(query)
-            # flash attention
-            flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+            # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
+            block_size = kv_cache[1].shape[3]
+            vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                start_seq_q,
-                end_seq_q,
-                start_seq,
-                end_seq,
-                1,
-                max_s,
-                0.0,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
                 self.softmax_scale,
-                False,
-                False,
-                False,
-                0,
-                None,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )

+        self.kv_head_mapping = torch.arange(
+            0, self.num_groups, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_heads)
+
     def forward(
         self,
         hidden_states,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
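
The kv_head_mapping tensors introduced in these hunks tell the paged-attention kernel which KV head each query head reads from. A toy sketch of the three patterns used in this commit (sizes invented for illustration):

    import torch

    num_heads = 4  # query heads
    # Multi-head case (FlashNeoxAttention, FlashLlamaAttention): one KV head per query head.
    torch.arange(0, num_heads, dtype=torch.int32)                 # tensor([0, 1, 2, 3])
    # Multi-query case (FlashMQAttention, FlashRWAttention with num_heads_kv == 1):
    # every query head reads KV head 0.
    torch.zeros(num_heads, dtype=torch.int32)                     # tensor([0, 0, 0, 0])
    # Grouped case (FlashRWLargeAttention, model_type "RefinedWeb"): num_groups KV heads,
    # each shared by the num_heads query heads of its group.
    num_groups = 2
    torch.arange(0, num_groups, dtype=torch.int32).repeat_interleave(num_heads)
    # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
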
@@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

+        vllm_cache_ops.reshape_and_cache(
+            kv[:, :, 0].contiguous(),
+            kv[:, :, 1].contiguous(),
+            kv_cache[0],
+            kv_cache[1],
+            slots,
+        )
+
+        # output
+        attn_output = torch.empty_like(query)
+
         # Prefill
-        if prefill:
-            # Copy to layer past
-            layer_past[...] = kv
+        if start_seq_prefill is not None:
             # Expand to query shape
             kv = (
                 kv.unsqueeze(2)
@@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module):
                 .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
             )

-            # output
-            attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
                 attn_output,
-                start_seq,
-                end_seq,
-                start_seq,
-                end_seq,
+                start_seq_prefill,
+                end_seq_prefill,
+                start_seq_prefill,
+                end_seq_prefill,
                 max_s,
                 max_s,
                 0.0,
@@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
-            # Expand to query shape
-            kv = (
-                layer_past.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
-
-            # output
-            attn_output = torch.empty_like(query)
-            # flash attention
-            flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
+            # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
+            block_size = kv_cache[1].shape[3]
+            vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                start_seq_q,
-                end_seq_q,
-                start_seq,
-                end_seq,
-                1,
-                max_s,
-                0.0,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
                 self.softmax_scale,
-                False,
-                False,
-                False,
-                0,
-                None,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
             )

         return self.dense(
@@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module):
         residual,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module):
                 ln_hidden_states,
                 cos,
                 sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache,
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                layer_past,
-                past_present_indices,
-                prefill,
             )

             mlp_output = self.mlp(ln_hidden_states)
@@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache,
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                layer_past,
-                past_present_indices,
-                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(
@@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module):
         residual,
         cos,
         sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
         ln_mlp, _ = self.ln_mlp(residual)
@@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module):
             ln_attn,
             cos,
             sin,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            start_seq_prefill,
+            end_seq_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
             max_s,
-            layer_past,
-            past_present_indices,
-            prefill,
         )

         # MLP.
@@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (
-                2,
-                self.h[0].self_attention.num_heads_kv,
-                self.h[0].self_attention.head_size,
-            )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (
-                self.h[0].self_attention.num_groups,
-                2,
-                self.h[0].self_attention.head_size,
-            )
+            self.cache_size = self.h[0].self_attention.num_groups
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
@@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values=None,
-        pre_allocate_past_size: Optional[int] = None,
-    ):
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+    ) -> torch.Tensor:
         hidden_states = self.word_embeddings(input_ids)

-        # Prefill
-        if past_key_values is None:
-            assert pre_allocate_past_size is not None
-
-            prefill = True
-
-            # Create past tensor
-            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(input_ids),
-                    len(self.h),
-                    *self.cache_size,
-                )
-            )
-        # Decode
-        else:
-            prefill = False
-
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
@@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 residual,
                 cos,
                 sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
-                past_present_indices,
-                prefill,
             )

-        if prefill:
-            present = past_key_values
-            # Create padded past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    pre_allocate_past_size,
-                    len(self.h),
-                    *self.cache_size,
-                )
-            )
-            # We slice only once instead of at every layer
-            past_key_values[past_present_indices] = present
-
         hidden_states, _ = self.ln_f(hidden_states, residual)

-        return hidden_states, past_key_values
+        return hidden_states


 class FlashRWForCausalLM(FlashRWPreTrainedModel):
@@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values: Optional[torch.Tensor] = None,
-        pre_allocate_past_size: Optional[int] = None,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ):
-        hidden_states, present = self.transformer(
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(
             input_ids,
             position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            start_seq_prefill,
+            end_seq_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
             max_s,
-            past_present_indices,
-            past_key_values,
-            pre_allocate_past_size,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
-        return logits, present
+        return logits
@@ -3,11 +3,15 @@ import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional
+from typing import Optional, List, Tuple

 # Flash attention imports
 import flash_attn_cuda

+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module):
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
         )
+        self.kv_head_mapping = torch.zeros(
+            self.num_heads, dtype=torch.int32, device=weights.device
+        )

     def forward(
         self,
         hidden_states,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         qkv = self.c_attn(hidden_states)

@@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)

+        vllm_cache_ops.reshape_and_cache(
+            key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
+        )
+
+        # output
+        attn_output = torch.empty_like(query)
+
         # Prefill
-        if prefill:
-            # Copy to layer past
-            layer_past[...] = key_value
+        if start_seq_prefill is not None:
             # Expand from 1 to num_heads
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

-            # output
-            attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
                 attn_output,
-                start_seq,
-                end_seq,
-                start_seq,
-                end_seq,
+                start_seq_prefill,
+                end_seq_prefill,
+                start_seq_prefill,
+                end_seq_prefill,
                 max_s,
                 max_s,
                 0.0,
@@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = key_value
-            # Expand from 1 to num_heads
-            key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
-
-            # output
-            attn_output = torch.empty_like(query)
-            # flash attention
-            flash_attn_cuda.fwd(
-                query,
-                torch.select(key_value, dim=1, index=0),
-                torch.select(key_value, dim=1, index=1),
+            # kv_cache[1] => [num_blocks, 1, head_size, block_size]
+            block_size = kv_cache[1].shape[3]
+            vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                start_seq_q,
-                end_seq_q,
-                start_seq,
-                end_seq,
-                1,
-                max_s,
-                0.0,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
                 self.softmax_scale,
-                False,
-                False,
-                False,
-                0,
-                None,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
             )

         return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -361,27 +357,25 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        start_seq_prefill,
+        end_seq_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
         max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)

         hidden_states = self.attn(
             hidden_states,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            start_seq_prefill,
+            end_seq_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
             max_s,
-            layer_past,
-            past_present_indices,
-            prefill,
         )

         hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values=None,
-        pre_allocate_past_size: Optional[int] = None,
-    ):
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+    ) -> torch.Tensor:
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)

         if self.process_group.size() > 1:
             torch.distributed.all_reduce(hidden_states, group=self.process_group)

-        # Prefill
-        if past_key_values is None:
-            assert pre_allocate_past_size is not None
-
-            prefill = True
-
-            # Create past tensor
-            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
-            past_key_values = hidden_states.new_zeros(
-                (len(input_ids), len(self.h), 2, 1, self.head_size)
-            )
-        # Decode
-        else:
-            prefill = False
-
         residual = None
         for i, layer in enumerate(self.h):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                start_seq_prefill,
+                end_seq_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
-                past_present_indices,
-                prefill,
             )

-        if prefill:
-            present = past_key_values
-            # Create padded past tensor
-            past_key_values = hidden_states.new_empty(
-                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
-            )
-            # We slice only once instead of at every layer
-            past_key_values[past_present_indices] = present
-
         hidden_states, _ = self.ln_f(hidden_states, residual)

-        return hidden_states, past_key_values
+        return hidden_states


 class FlashSantacoderForCausalLM(nn.Module):
@@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module):

     def forward(
         self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values: Optional[torch.Tensor] = None,
-        pre_allocate_past_size: Optional[int] = None,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        start_seq_prefill: Optional[torch.Tensor],
+        end_seq_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ):
-        hidden_states, present = self.transformer(
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(
             input_ids,
             position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            start_seq_prefill,
+            end_seq_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
             max_s,
-            past_present_indices,
-            past_key_values,
-            pre_allocate_past_size,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
-        return logits, present
+        return logits
@@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM):
             model = FlashGPTNeoXForCausalLM(config, weights)

         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashNeoXSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.gpt_neox.layers),
+            num_kv_heads=model.gpt_neox.num_heads,
+            head_size=model.gpt_neox.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
@@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM):
             model = FlashRWForCausalLM(config, weights)

         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashRWSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.transformer.h),
+            num_kv_heads=model.transformer.cache_size,
+            head_size=model.transformer.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
@@ -62,10 +62,12 @@ class FlashSantacoderSharded(FlashCausalLM):
             model = FlashSantacoderForCausalLM(config, weights)

         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashSantacoderSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.transformer.h),
+            num_kv_heads=1,
+            head_size=model.transformer.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
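
The sharded model classes above now pass num_layers, num_kv_heads and head_size to their base class instead of requires_padding. This diff does not show what the base class does with those values, but a plausible (hypothetical) reading is that they size the per-layer paged KV cache handed to the forwards as kv_cache. A sketch under that assumption, with invented num_blocks/block_size defaults and the value-cache layout taken from the diff comments:

    import torch

    def allocate_kv_cache(num_layers, num_kv_heads, head_size,
                          num_blocks=128, block_size=16,
                          dtype=torch.float16, device="cpu"):
        # Hypothetical helper, not part of this commit: one (key_cache, value_cache)
        # pair per layer, each shaped [num_blocks, num_kv_heads, head_size, block_size].
        return [
            (
                torch.empty(num_blocks, num_kv_heads, head_size, block_size,
                            dtype=dtype, device=device),
                torch.empty(num_blocks, num_kv_heads, head_size, block_size,
                            dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]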