diff --git a/Dockerfile b/Dockerfile index 2a313c25..1a969383 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \ /opt/conda/bin/conda clean -ya - # Build Flash Attention CUDA kernels FROM kernel-builder as flash-att-builder @@ -109,6 +108,16 @@ COPY server/custom_kernels/ . # Build specific version of transformers RUN python setup.py build +# Build vllm CUDA kernels +FROM kernel-builder as vllm-builder + +WORKDIR /usr/src + +COPY server/Makefile-vllm Makefile + +# Build specific version of vllm +RUN make build-vllm + # Text Generation Inference base image FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base @@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages -# Copy build artifacts from transformers builder +# Copy build artifacts from custom kernels builder COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels +# Copy build artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/router/src/infer.rs b/router/src/infer.rs index 8d93d2a1..d0d22d3b 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -290,7 +290,6 @@ async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - tracing::info!("{token_budget} {batch_max_tokens}"); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue diff --git a/router/src/main.rs b/router/src/main.rs index 474f4e06..47d48e3f 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -180,7 +180,11 @@ fn main() -> Result<(), std::io::Error> { let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) .await .expect("Could not connect to server"); - + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .expect("Unable to clear cache"); // Get info from the shard let shard_info = sharded_client .info() diff --git a/router/src/queue.rs b/router/src/queue.rs index a3a607e7..75009fcd 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -305,7 +305,7 @@ mod tests { watermark: false, }, stopping_parameters: StoppingCriteriaParameters { - ignore_eos_token: true, + ignore_eos_token: false, max_new_tokens: 1, stop_sequences: vec![], }, diff --git a/router/src/server.rs b/router/src/server.rs index 04d1269b..95d4d4c1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -152,7 +152,7 @@ async fn generate( let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); - // tracing::debug!("Input: {}", req.0.inputs); + tracing::debug!("Input: {}", req.0.inputs); let compute_characters = req.0.inputs.chars().count(); let mut add_prompt = None; @@ -286,7 +286,7 @@ async fn generate( } tracing::debug!("Output: {}", output_text); - // tracing::info!("Success"); + tracing::info!("Success"); let response = GenerateResponse { generated_text: 
output_text, diff --git a/server/Makefile-vllm b/server/Makefile-vllm new file mode 100644 index 00000000..b9725ba3 --- /dev/null +++ b/server/Makefile-vllm @@ -0,0 +1,13 @@ +vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9 + +vllm: + # Clone vllm + git clone https://github.com/OlivierDehaene/vllm.git + +build-vllm: vllm + cd vllm && git fetch && git checkout $(vllm_commit) + cd vllm && python setup.py build + +install-vllm: build-vllm + pip uninstall vllm -y || true + cd vllm && python setup.py install \ No newline at end of file diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index bb7fcbef..07765e88 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -24,13 +24,15 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from vllm import attention_ops -from vllm import cache_ops # Flash attention imports import flash_attn_cuda import dropout_layer_norm +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -124,6 +126,9 @@ class FlashLlamaAttention(torch.nn.Module): weights=weights, bias=False, ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, @@ -145,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - cache_ops.reshape_and_cache( + vllm_cache_ops.reshape_and_cache( qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots ) @@ -178,11 +183,12 @@ class FlashLlamaAttention(torch.nn.Module): else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] block_size = kv_cache[1].shape[3] - attention_ops.single_query_cached_kv_attention( + vllm_attention_ops.single_query_cached_kv_attention( attn_output, qkv[:, 0], kv_cache[0], kv_cache[1], + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 9c1020a5..9049878a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -25,11 +25,15 @@ from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module): self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, 
- layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = qkv[:, 1:] + vllm_cache_ops.reshape_and_cache( + qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + ) - # output - attn_output = torch.empty_like(qkv[:, 0]) + # output tensor + attn_output = torch.empty_like(qkv[:, 0]) + + # Prefill + if start_seq_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - query = qkv[:, 0] - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = qkv[:, 1:] - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - layer_past[:, 0], - layer_past[:, 1], + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + qkv[:, 0], + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module): ln1_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module): hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create 
past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( @@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - past_key_values[:, i], - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.final_layer_norm(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): @@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.gpt_neox( + ) -> torch.Tensor: + hidden_states = self.gpt_neox( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fa35c359..44aa7cb1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -4,11 +4,15 @@ import torch.distributed from torch import nn from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + if self.num_heads_kv == 1: + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) + else: + self.kv_head_mapping = torch.arange( + 0, 
self.num_heads, dtype=torch.int32, device=weights.device + ) + def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) @@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = kv - # Expand to query shape - kv = kv.expand(-1, 2, self.num_heads, self.head_size) + vllm_cache_ops.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output + attn_output = torch.empty_like(query) + + # Prefill + if start_seq_prefill is not None: + if self.num_heads_kv == 1: + # Expand to query shape + kv = kv.expand(-1, 2, self.num_heads, self.head_size) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = kv - # Expand to query shape - kv = layer_past.expand(-1, 2, self.num_heads, self.head_size) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), + # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + self.kv_head_mapping = torch.arange( + 0, self.num_groups, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_heads) + def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) + vllm_cache_ops.reshape_and_cache( + kv[:, :, 0].contiguous(), + kv[:, :, 1].contiguous(), + kv_cache[0], + kv_cache[1], + slots, + ) + + # output + attn_output = torch.empty_like(query) + # Prefill - if prefill: - # Copy to layer past - layer_past[...] 
= kv + if start_seq_prefill is not None: # Expand to query shape kv = ( kv.unsqueeze(2) @@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module): .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) ) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = kv - # Expand to query shape - kv = ( - layer_past.unsqueeze(2) - .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) - .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) - ) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(kv, dim=2, index=0), - torch.select(kv, dim=2, index=1), + # kv_cache[1] => [num_blocks, num_groups, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense( @@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module): ln_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) mlp_output = self.mlp(ln_hidden_states) @@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module): hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): ln_attn, residual = self.ln_attn(hidden_states, residual) ln_mlp, _ = self.ln_mlp(residual) @@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module): ln_attn, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) # MLP. 
@@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel): for layer_id in range(config.num_hidden_layers) ] ) - self.cache_size = ( - 2, - self.h[0].self_attention.num_heads_kv, - self.h[0].self_attention.head_size, - ) + self.cache_size = self.h[0].self_attention.num_heads_kv elif config.model_type == "RefinedWeb": self.h = nn.ModuleList( [ @@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel): for layer_id in range(config.num_hidden_layers) ] ) - self.cache_size = ( - self.h[0].self_attention.num_groups, - 2, - self.h[0].self_attention.head_size, - ) + self.cache_size = self.h[0].self_attention.num_groups else: raise NotImplementedError( f"model_type {config.model_type} is not supported." @@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.h), - *self.cache_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( @@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - torch.select(past_key_values, dim=1, index=i), - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.h), - *self.cache_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.ln_f(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashRWForCausalLM(FlashRWPreTrainedModel): @@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.transformer( + ) -> torch.Tensor: + hidden_states = self.transformer( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + 
slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 4eb0034d..04eedef7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -3,11 +3,15 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.c_attn(hidden_states) @@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) + vllm_cache_ops.reshape_and_cache( + key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output + attn_output = torch.empty_like(query) + # Prefill - if prefill: - # Copy to layer past - layer_past[...] 
= key_value + if start_seq_prefill is not None: # Expand from 1 to num_heads key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = key_value - # Expand from 1 to num_heads - key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(key_value, dim=1, index=0), - torch.select(key_value, dim=1, index=1), + # kv_cache[1] => [num_blocks, 1, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -361,27 +357,25 @@ class Block(nn.Module): self, hidden_states, residual, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.attn( hidden_states, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) if self.process_group.size() > 1: torch.distributed.all_reduce(hidden_states, group=self.process_group) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_zeros( - (len(input_ids), len(self.h), 2, 1, self.head_size) - ) - # Decode - else: - prefill = False - residual = None for i, layer in enumerate(self.h): hidden_states, residual = layer( hidden_states, residual, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - torch.select(past_key_values, dim=1, index=i), - 
past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - (pre_allocate_past_size, len(self.h), 2, 1, self.head_size) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.ln_f(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashSantacoderForCausalLM(nn.Module): @@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.transformer( + ) -> torch.Tensor: + hidden_states = self.transformer( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 4847571d..e64af0c6 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM): model = FlashGPTNeoXForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashNeoXSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.gpt_neox.layers), + num_kv_heads=model.gpt_neox.num_heads, + head_size=model.gpt_neox.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 5f963bfb..a55f9118 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM): model = FlashRWForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashRWSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.transformer.h), + num_kv_heads=model.transformer.cache_size, + head_size=model.transformer.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index f4363e19..ef202785 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -62,10 +62,12 @@ class FlashSantacoderSharded(FlashCausalLM): model = FlashSantacoderForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - 
super(FlashCausalLM, self).__init__( + super(FlashSantacoderSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.transformer.h), + num_kv_heads=1, + head_size=model.transformer.head_size, dtype=dtype, device=device, rank=rank,
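Note on the kv_head_mapping tensors introduced throughout this patch: they tell vllm_attention_ops.single_query_cached_kv_attention which KV-cache head each query head reads during decode, so the paged cache no longer has to be expanded to the query shape the way the old flash_attn_cuda decode path required. The patch uses three layouts: all zeros for multi-query attention (flash_santacoder, FlashRWAttention with num_heads_kv == 1), an identity arange when every query head owns a KV head (flash_llama, flash_neox), and arange over groups followed by repeat_interleave for RefinedWeb's grouped heads (FlashRWLargeAttention). The snippet below is a minimal, self-contained PyTorch sketch of those three layouts; the head counts (8 query heads, 2 KV groups) are illustrative assumptions, not values from the patch.

import torch

# Illustrative head counts (assumptions for this sketch, not taken from the patch).
num_query_heads = 8
num_kv_groups = 2
heads_per_group = num_query_heads // num_kv_groups

# Multi-query attention: a single KV head shared by every query head
# (flash_santacoder, FlashRWAttention when num_heads_kv == 1).
mqa_mapping = torch.zeros(num_query_heads, dtype=torch.int32)

# One KV head per query head: identity mapping
# (flash_llama, flash_neox, FlashRWAttention otherwise).
mha_mapping = torch.arange(0, num_query_heads, dtype=torch.int32)

# Grouped KV heads: each group of query heads shares one KV head
# (FlashRWLargeAttention builds this with arange over groups + repeat_interleave).
gqa_mapping = torch.arange(0, num_kv_groups, dtype=torch.int32).repeat_interleave(
    heads_per_group
)

print(mqa_mapping)  # tensor([0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)
print(mha_mapping)  # tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
print(gqa_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)

Building the mapping once in each attention module's constructor keeps the per-step decode call to a single kernel launch over the paged KV cache, instead of materializing expanded key/value tensors on every forward as the removed code did.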