From b890c8c47db93829bb3830a88895364aaba561fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 26 Jul 2024 08:51:32 +0000 Subject: [PATCH] Update vLLM dependency to 0.5.3.post1 --- server/Makefile-vllm | 4 +- .../layers/attention/cuda.py | 11 +- .../flash_deepseek_v2_modeling.py | 183 +----------------- 3 files changed, 10 insertions(+), 188 deletions(-) diff --git a/server/Makefile-vllm b/server/Makefile-vllm index f1f80529..0e24bcf6 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,9 +1,9 @@ -commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b +commit_cuda := 1ce677148e4d0252c7f668b87905bcec210207a6 commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ - git clone https://github.com/Narsil/vllm.git vllm; \ + git clone https://github.com/danieldk/vllm.git vllm; \ fi cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 8ff51596..e03d5d30 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -9,8 +9,7 @@ is_sm75 = major == 7 and minor == 5 _PARTITION_SIZE = 512 try: - from vllm._C import cache_ops - from vllm._C import ops + import vllm._custom_ops as ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" @@ -29,8 +28,8 @@ def reshape_and_cache( key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 + ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0 ) @@ -114,7 +113,7 @@ def paged_attention( if softcap is not None: raise RuntimeError("Paged attention doesn't support softcapping") input_lengths = seqlen.input_lengths - from vllm._C import ops + import vllm._custom_ops as ops use_v1 = max_s <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512 @@ -134,6 +133,7 @@ def paged_attention( None, "auto", 1.0, + 1.0, ) else: # Run PagedAttention V2. @@ -167,6 +167,7 @@ def paged_attention( None, "auto", 1.0, + 1.0, ) return out diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index f5b2ba0e..93dab033 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.distributed @@ -38,6 +38,7 @@ from text_generation_server.utils.weights import Weights from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig +from vllm.model_executor.layers.fused_moe import fused_experts, grouped_topk class DeepseekV2Config(PretrainedConfig): @@ -798,183 +799,3 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) return logits, speculative_logits - - -# Functions below are from vLLM: -# -# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397 -# -# Remove after we have synced our version with upstream. - - -def grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor]: - scores = torch.softmax(gating_output, dim=-1) - num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ - 1 - ] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) - .reshape(num_token, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - return topk_weights, topk_ids - - -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], -) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - if M <= E: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } - return config - - -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -): - # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - - import triton.language as tl - from vllm import _custom_ops as ops - from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_moe_configs, - invoke_fused_moe_kernel, - moe_align_block_size, - ) - - M, _ = hidden_states.shape - E, N, _ = w1.shape - - if override_config: - config = override_config - else: - # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) - - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config( - M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None - ) - - intermediate_cache1 = torch.empty( - (M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache3 = torch.empty( - (M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - if inplace: - return torch.sum( - intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states, - ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)