Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

add falcon, santacoder and neox support

This commit is contained in: parent ddfc02f2a4, commit 16f796f735

Dockerfile (16 lines changed)
@@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya

# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder

@@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build

# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder

WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm

# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base

@@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy build artifacts from transformers builder
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels

# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
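The new vllm-builder stage compiles the paged-attention CUDA kernels from the pinned vllm fork, and the runtime image only receives the built artifacts copied into site-packages. A quick smoke test such as the following (hypothetical, not part of the commit) can confirm inside the built image that the extension modules imported later in this diff actually resolve:

```python
# Hypothetical smoke test: run inside the built image to confirm the copied
# CUDA extensions used by the flash models are importable before serving.
import importlib

for module in ("flash_attn_cuda", "vllm_cache_ops", "vllm_attention_ops"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as exc:
        print(f"{module}: MISSING ({exc})")
```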
@@ -290,7 +290,6 @@ async fn batching_task(
};

let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
tracing::info!("{token_budget} {batch_max_tokens}");

// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
@@ -180,7 +180,11 @@ fn main() -> Result<(), std::io::Error> {
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");

// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
// Get info from the shard
let shard_info = sharded_client
.info()
@@ -305,7 +305,7 @@ mod tests {
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: true,
ignore_eos_token: false,
max_new_tokens: 1,
stop_sequences: vec![],
},
@@ -152,7 +152,7 @@ async fn generate(
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");

// tracing::debug!("Input: {}", req.0.inputs);
tracing::debug!("Input: {}", req.0.inputs);

let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None;

@@ -286,7 +286,7 @@ async fn generate(
}

tracing::debug!("Output: {}", output_text);
// tracing::info!("Success");
tracing::info!("Success");

let response = GenerateResponse {
generated_text: output_text,
server/Makefile-vllm (new file, 13 lines)

@@ -0,0 +1,13 @@
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9

vllm:
	# Clone vllm
	git clone https://github.com/OlivierDehaene/vllm.git

build-vllm: vllm
	cd vllm && git fetch && git checkout $(vllm_commit)
	cd vllm && python setup.py build

install-vllm: build-vllm
	pip uninstall vllm -y || true
	cd vllm && python setup.py install
@@ -24,13 +24,15 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from vllm import attention_ops
from vllm import cache_ops

# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,

@@ -124,6 +126,9 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)

def forward(
self,

@@ -145,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)

cache_ops.reshape_and_cache(
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)

@@ -178,11 +183,12 @@ class FlashLlamaAttention(torch.nn.Module):
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
attention_ops.single_query_cached_kv_attention(
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
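The decode path above replaces the contiguous layer_past tensor with vllm's block-paged KV cache: each sequence owns a row in block_tables, each block holds block_size token positions (block_size = kv_cache[1].shape[3] in the hunk), and a token's key/value live at a flat slot index. Below is a minimal illustrative sketch of that addressing, with a toy block_size and a helper name that does not exist in the commit:

```python
import torch

def slot_for_token(block_tables: torch.Tensor, seq: int, pos: int, block_size: int) -> int:
    """Map (sequence, token position) to a flat slot index in the paged KV cache.

    block_tables[seq] lists the physical block numbers assigned to that sequence
    in logical order; each block stores block_size consecutive token positions.
    """
    physical_block = int(block_tables[seq, pos // block_size])
    return physical_block * block_size + pos % block_size

# Toy example: two sequences, block_size of 4, blocks handed out non-contiguously.
block_size = 4
block_tables = torch.tensor([[0, 3], [1, 2]], dtype=torch.int32)
print(slot_for_token(block_tables, seq=0, pos=5, block_size=block_size))  # block 3, offset 1 -> 13
print(slot_for_token(block_tables, seq=1, pos=2, block_size=block_size))  # block 1, offset 2 -> 6
```

The slots tensor passed to reshape_and_cache and the block_tables/input_lengths pair passed to single_query_cached_kv_attention in the hunks above are the per-token and per-sequence views of this same mapping.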
@@ -25,11 +25,15 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,

@@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module):
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)

def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

@@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)

# Prefill
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)

# output
attn_output = torch.empty_like(qkv[:, 0])
# output tensor
attn_output = torch.empty_like(qkv[:, 0])

# Prefill
if start_seq_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,

@@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = qkv[:, 1:]

# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

@@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)

@@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)

ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)

@@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)

hidden_states, residual = self.post_attention_layernorm(

@@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):

def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)

# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None

prefill = True

# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False

# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(

@@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
past_key_values[:, i],
past_present_indices,
prefill,
)

if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present

hidden_states, _ = self.final_layer_norm(hidden_states, residual)

return hidden_states, past_key_values
return hidden_states


class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):

@@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):

def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.gpt_neox(
) -> torch.Tensor:
hidden_states = self.gpt_neox(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
return logits, present
return logits
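A recurring addition across these attention modules is kv_head_mapping, the tensor that tells the paged-attention kernel which KV head each query head should read. The hunks use three variants: torch.arange for one KV head per query head (NeoX, Llama), torch.zeros for a single shared KV head (multi-query attention), and arange(num_groups).repeat_interleave(num_heads) for grouped KV heads (the large RefinedWeb layout). A small sketch with toy sizes, just to visualize the three mappings taken from the diff:

```python
import torch

num_heads = 8   # query heads (toy value)
num_groups = 2  # KV head groups for the grouped case (toy value)

# One KV head per query head (FlashNeoxAttention / FlashLlamaAttention above).
mha_mapping = torch.arange(0, num_heads, dtype=torch.int32)

# A single KV head shared by every query head (multi-query attention).
mqa_mapping = torch.zeros(num_heads, dtype=torch.int32)

# num_groups KV heads, each shared by a contiguous run of query heads
# (FlashRWLargeAttention below); here num_heads is the per-group head count.
grouped_mapping = torch.arange(0, num_groups, dtype=torch.int32).repeat_interleave(num_heads)

print(mha_mapping.tolist())      # [0, 1, 2, 3, 4, 5, 6, 7]
print(mqa_mapping.tolist())      # [0, 0, 0, 0, 0, 0, 0, 0]
print(grouped_mapping.tolist())  # [0]*8 + [1]*8
```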
@@ -4,11 +4,15 @@ import torch.distributed
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,

@@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)

if self.num_heads_kv == 1:
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
else:
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)

def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)

@@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
vllm_cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)

# output
attn_output = torch.empty_like(query)

# Prefill
if start_seq_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)

# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,

@@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)

# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

@@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)

self.kv_head_mapping = torch.arange(
0, self.num_groups, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_heads)

def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)

@@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

vllm_cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
)

# output
attn_output = torch.empty_like(query)

# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
if start_seq_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)

@@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module):
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)

# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,

@@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)

# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)

return self.dense(

@@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)

@@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)

mlp_output = self.mlp(ln_hidden_states)

@@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)

hidden_states, residual = self.post_attention_layernorm(

@@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)

@@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module):
ln_attn,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)

# MLP.

@@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
2,
self.h[0].self_attention.num_heads_kv,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_heads_kv
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[

@@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
self.h[0].self_attention.num_groups,
2,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_groups
else:
raise NotImplementedError(
f"model_type {config.model_type} is not supported."

@@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel):

def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)

# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None

prefill = True

# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
*self.cache_size,
)
)
# Decode
else:
prefill = False

# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(

@@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)

if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present

hidden_states, _ = self.ln_f(hidden_states, residual)

return hidden_states, past_key_values
return hidden_states


class FlashRWForCausalLM(FlashRWPreTrainedModel):

@@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):

def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits
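Across these hunks the per-layer copy `layer_past[past_present_indices] = kv` is replaced by a single call to vllm_cache_ops.reshape_and_cache, which scatters the new key/value vectors into the paged cache at the positions given by `slots`. The snippet below is a pure-PyTorch illustration of the value-cache half of that operation only, using the value-cache layout stated in the diff comments ([num_blocks, num_kv_heads, head_size, block_size]); the real kernel also packs the key cache in a vectorized layout, which is omitted here:

```python
import torch

def reference_value_cache_write(value: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor) -> None:
    """Illustrative stand-in for the value-cache half of reshape_and_cache.

    value:       [num_tokens, num_kv_heads, head_size]   new V vectors for this step
    value_cache: [num_blocks, num_kv_heads, head_size, block_size]
    slots:       [num_tokens]   flat slot index = block * block_size + offset
    """
    block_size = value_cache.shape[3]
    blocks = torch.div(slots, block_size, rounding_mode="floor")
    offsets = slots % block_size
    # Scatter each token's V vector into its (block, offset) cell.
    value_cache[blocks, :, :, offsets] = value

# Toy shapes to exercise the helper.
num_blocks, num_kv_heads, head_size, block_size = 4, 2, 8, 16
value_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size)
value = torch.randn(3, num_kv_heads, head_size)
slots = torch.tensor([5, 16, 33])
reference_value_cache_write(value, value_cache, slots)
```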
@@ -3,11 +3,15 @@ import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,

@@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)

def forward(
self,
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.c_attn(hidden_states)

@@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)

vllm_cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)

# output
attn_output = torch.empty_like(query)

# Prefill
if prefill:
# Copy to layer past
layer_past[...] = key_value
if start_seq_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,

@@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)

# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
# kv_cache[1] => [num_blocks, 1, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)

return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))

@@ -361,27 +357,25 @@ class Block(nn.Module):
self,
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
hidden_states, residual = self.ln_1(hidden_states, residual)

hidden_states = self.attn(
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)

hidden_states, residual = self.ln_2(hidden_states, residual)

@@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module):

def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)

if self.process_group.size() > 1:
torch.distributed.all_reduce(hidden_states, group=self.process_group)

# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None

prefill = True

# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
# Decode
else:
prefill = False

residual = None
for i, layer in enumerate(self.h):
hidden_states, residual = layer(
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)

if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present

hidden_states, _ = self.ln_f(hidden_states, residual)

return hidden_states, past_key_values
return hidden_states


class FlashSantacoderForCausalLM(nn.Module):

@@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module):

def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits
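The new lm_head_indices argument that appears in these forward signatures lets the server compute logits only for the positions that actually need them: during prefill, just the last token of each prompt. The forward gathers those rows before applying the LM head, exactly as in the `hidden_states = hidden_states[lm_head_indices]` lines above. A toy illustration of that gather (the cumsum construction of the indices is an assumption for the example, not code from the commit):

```python
import torch

# Three prompts with lengths 3, 5 and 2, flattened into one prefill batch.
input_lengths = torch.tensor([3, 5, 2])

# Last-token index of every sequence inside the flattened batch.
lm_head_indices = torch.cumsum(input_lengths, dim=0) - 1   # tensor([2, 7, 9])

hidden_size = 16
hidden_states = torch.randn(int(input_lengths.sum()), hidden_size)

# Same gather the new forward() performs before the LM head
# when lm_head_indices is not None.
hidden_states = hidden_states[lm_head_indices]
print(hidden_states.shape)  # torch.Size([3, 16])
```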
@@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM):
model = FlashGPTNeoXForCausalLM(config, weights)

torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashNeoXSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype,
device=device,
rank=rank,
@@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM):
model = FlashRWForCausalLM(config, weights)

torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashRWSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
@@ -62,10 +62,12 @@ class FlashSantacoderSharded(FlashCausalLM):
model = FlashSantacoderForCausalLM(config, weights)

torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashSantacoderSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
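The visible change in these three model classes is which `__init__` the `super()` call dispatches to, plus the extra requires_padding / num_layers / num_kv_heads / head_size arguments that describe the KV-cache geometry. The sketch below uses stand-in classes with toy signatures (not the real FlashCausalLM, whose definition is not part of this diff) to show what each of the two `super()` forms in the hunks resolves to:

```python
# Stand-in classes to illustrate the two super() forms above; the argument
# names mirror the diff, the bodies are toy placeholders.
class Model:
    def __init__(self, **kwargs):
        print("Model.__init__ got", sorted(kwargs))


class FlashCausalLM(Model):
    def __init__(self, *, num_layers, num_kv_heads, head_size, **kwargs):
        print("FlashCausalLM.__init__ sees the KV-cache geometry:",
              num_layers, num_kv_heads, head_size)
        super().__init__(**kwargs)


class FlashNeoXSharded(FlashCausalLM):
    def __init__(self):
        # super(FlashCausalLM, self): the MRO lookup starts *after*
        # FlashCausalLM, so its __init__ is skipped and Model.__init__ runs.
        # super(FlashNeoXSharded, self): the lookup starts after
        # FlashNeoXSharded, so FlashCausalLM.__init__ runs and receives
        # the cache-geometry arguments shown in the hunks.
        super(FlashNeoXSharded, self).__init__(
            num_layers=2, num_kv_heads=8, head_size=64, requires_padding=False
        )


FlashNeoXSharded()
# FlashCausalLM.__init__ sees the KV-cache geometry: 2 8 64
# Model.__init__ got ['requires_padding']
```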