diff --git a/Dockerfile b/Dockerfile
index e79372a3..46089824 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -160,11 +160,6 @@ WORKDIR /usr/src
 COPY server/Makefile-selective-scan Makefile
 RUN make build-all
 
-# Build megablocks
-FROM kernel-builder as megablocks-builder
-
-RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
-
 # Text Generation Inference base image
 FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
 
@@ -186,9 +181,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
         curl \
         && rm -rf /var/lib/apt/lists/*
 
-# Copy conda with PyTorch and Megablocks installed
-COPY --from=megablocks-builder /opt/conda /opt/conda
-
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
diff --git a/server/Makefile b/server/Makefile
index da5171b2..32d01709 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -17,9 +17,6 @@ gen-server:
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch text_generation_server/pb/__init__.py
 
-install-megablocks:
-	pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
-
 install: gen-server
 	pip install pip --upgrade
 	pip install -r requirements_cuda.txt
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index 71c6cabe..803b3d1f 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
 
 
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index c9c1d520..17660e8b 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -4,7 +4,7 @@ vllm-cuda:
 	git clone https://github.com/vllm-project/vllm.git vllm
 
 build-vllm-cuda: vllm-cuda
-	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
+	cd vllm && git fetch && git checkout b7782002e1da25de77e0b1890ff8b72dd4df917c
 	cd vllm && python setup.py build
 
 install-vllm-cuda: build-vllm-cuda
diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py
index 2e6ae086..4be8b1b9 100644
--- a/server/text_generation_server/models/cache_manager.py
+++ b/server/text_generation_server/models/cache_manager.py
@@ -43,7 +43,7 @@ class CacheManager:
         ]
         self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
         self.slots = torch.arange(
-            0, num_blocks * self.block_size, dtype=torch.int32
+            0, num_blocks * self.block_size, dtype=torch.int64
         ).view(num_blocks, self.block_size)
 
     def allocate(
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 985bbd8e..9208a595 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -61,6 +61,7 @@ class CohereConfig(PretrainedConfig):
         attention_bias=False,
         attention_dropout=0.0,
         logit_scale=1.0,
+        use_qk_norm=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class CohereConfig(PretrainedConfig):
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.logit_scale = logit_scale
+        self.use_qk_norm = use_qk_norm
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -175,16 +177,28 @@ class FlashCohereAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights)
 
+        self.use_qk_norm = config.use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = FastRMSNorm.load(
+                prefix=f"{prefix}.q_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+            self.k_norm = FastRMSNorm.load(
+                prefix=f"{prefix}.k_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+        else:
+            self.q_norm = None
+            self.k_norm = None
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=config.attention_bias,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -199,21 +213,25 @@ class FlashCohereAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        query, kv = qkv.split(
+        query, key, value = qkv.split(
             [
                 self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=1,
         )
+        if self.use_qk_norm:
+            query = self.q_norm(query)
+            key = self.k_norm(key)
+
         query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key = key.view(-1, self.num_key_value_heads, self.head_size)
+        value = value.view(-1, self.num_key_value_heads, self.head_size)
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, key, cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
         # output tensor
         attn_output = torch.empty_like(query)
@@ -223,8 +241,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             flash_attn.attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                key,
+                value,
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
@@ -235,9 +253,9 @@ class FlashCohereAttention(torch.nn.Module):
             paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.kv_head_mapping,
+                self.num_key_value_heads,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index dd0bcca5..92423d89 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -16,14 +16,13 @@
 import torch
 import torch.distributed
 
-import numpy as np
-
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
 
+from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     FastLinear,
@@ -37,14 +36,6 @@ from text_generation_server.utils.layers import (
 )
 from text_generation_server.utils.log import log_once
 
-HAS_MEGABLOCKS = True
-try:
-    import stk
-    import megablocks.ops as ops
-except ImportError:
-    logger.warning("Dbrx: megablocks is not installed")
-    HAS_MEGABLOCKS = False
-
 
 class DbrxAttentionConfig(PretrainedConfig):
     def __init__(
@@ -384,10 +375,6 @@ class DbrxAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -443,7 +430,7 @@ class DbrxAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
@@ -531,18 +518,6 @@ def round_up(x: torch.Tensor, value: int):
 
 
 class BlockSparseMoE(nn.Module):
-    """
-    Built on the paper and library Megablocks as described in
-    https://arxiv.org/abs/2211.15841. This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
-    """
-
     def __init__(self, prefix, config: DbrxConfig, weights):
         super().__init__()
         self.moe_normalize_expert_weights = (
@@ -572,241 +547,40 @@ class BlockSparseMoE(nn.Module):
         )
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
-        self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)
-
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.wv1 = torch.cat([w1, v1], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
 
         self.process_group = weights.process_group
 
-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
-
-    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
-        padded_tokens, _ = x.size()
-        assert padded_tokens % self.blocking == 0
-        assert self.ffn_dim % self.blocking == 0
-
-        # Offsets for the sparse matrix. All rows have the
-        # same number of nonzero blocks dictated by the
-        # dimensionality of a single expert.
-        block_rows = padded_tokens // self.blocking
-        blocks_per_row = self.ffn_dim // self.blocking
-        if self.offsets is None or block_rows > self.offsets_block_rows:
-            self.offsets = torch.arange(
-                0,
-                block_rows * blocks_per_row + 1,
-                blocks_per_row,
-                dtype=torch.int32,
-                device=x.device,
-            )
-            self.offsets_block_rows = block_rows
-            offsets = self.offsets
-        else:
-            offsets = self.offsets[: block_rows + 1]
-
-        # Indices for the sparse matrix. The indices for
-        # the intermediate matrix are dynamic depending
-        # on the mapping of tokens to experts.
-        column_indices = ops.topology(
-            padded_bins, self.blocking, block_rows, blocks_per_row
-        )
-
-        # For now, use meta init to save the device memory.
-        data = torch.empty(
-            column_indices.numel(),
-            self.blocking,
-            self.blocking,
-            dtype=x.dtype,
-            device="meta",
-        )
-        shape = (padded_tokens, self.ffn_dim * self.num_experts)
-        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-        return stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            False,
-            False,
-            False,
-        )
-
-    def indices_and_padded_bins(self, selected_experts: torch.Tensor):
-        # Sort the expert ids to produce the scatter/gather
-        # indices for the permutation.
-        # selected_experts = selected_experts.int()
-
-        # returns bin_ids == num of experts for this sequence ? == unique selected experts?
-        # and indices == how to sort tokens?
-        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
-        # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
-        # indices => [14, 32, 33, ...] => [num_tokens * top_k]
-
-        # Histogram the expert ids to identify the number of
-        # tokens routed to each expert.
-        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
-        # tokens_per_expert => [3, 0, 2, ...] => [num_experts]
-
-        # Round the token counts up to the block size used in
-        # the matrix muliplications. Caculate the starting
-        # position of each bin.
-
-        # List of size num_experts
-        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
-        # padded_tokens_per_expert => [128, O, 128, ...]
-
-        # Cumulative selected experts per token
-        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-        padded_bins = promote_scalar(padded_bins)
-        # padded_bins => [128, 128, 256, ...]
-
-        # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-        bins = promote_scalar(bins)
-        # bins => [3, 3, 5, ...]
-
-        return indices, bin_ids, bins, padded_bins, tokens_per_expert
-
-    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        selected_experts, weights = select_experts(
-            gate_logits, self.top_k, self.moe_normalize_expert_weights
-        )
-
-        (
-            indices,
-            bin_ids,
-            bins,
-            padded_bins,
-            _,
-        ) = self.indices_and_padded_bins(selected_experts)
-
-        # Permute tokens and pad to prepare expert computation
-        # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
-
-        # Create the sparse matrix topology
-        with torch.no_grad():
-            topo = self.topology(x, padded_bins)
-
-        # Perform the expert computation
-        # First Dense x Dense -> Sparse for w1 and v1,
-        # (top_k * sequence_length + padding, ffn_dim * n_experts)
-        x = stk.Matrix(
-            topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
-            * stk.ops.sdd(x, self.v1.t(), topo).data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-
-        # Then Sparse x Dense -> Dense for w2
-        # (top_k * sequence_length + padding, model_dim)
-        x = stk.ops.dsd(x, self.w2)
-
-        # Permute back and remove padding
-        # (sequence_length, model_dim)
-        x = ops.padded_scatter(
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
             x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
+            self.wv1,
+            self.w2,
+            router_logits,
             self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-
-        return x.view(*input_shape)
-
-    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                weights,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            weights.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        if self.moe_normalize_expert_weights:
-            weights = weights / torch.norm(
-                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
-            )
-        weights = weights.to(x.dtype)
-
-        # Expand to [num_experts, sequence_length, model_dim]
-        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
-
-        # Permute to [num_experts, model_dim, ffn_dim]
-        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
+            renormalize=self.moe_normalize_expert_weights,
+            inplace=True,
         )
-        v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-
-        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)
-
-        out = torch.bmm(
-            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
-        )
-        # Mask not selected experts
-        out *= weights.t().view(self.num_experts, -1, 1)
-
-        # Sum experts
-        out = out.sum(0)
 
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
 
-        return out
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
+        return out.view(*x.shape)
 
 
 class DenseMoE(nn.Module):
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index bd7596db..e66c56d1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -188,10 +188,6 @@ class FlashGemmaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -244,7 +240,7 @@ class FlashGemmaAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 3a269fc0..64ff6a85 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -176,10 +176,6 @@ class FlashLlamaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -232,7 +228,7 @@ class FlashLlamaAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index ed9306e0..abe74be9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -173,10 +173,6 @@ class MistralAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -236,7 +232,7 @@ class MistralAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index d71a3f0c..52ac8fa4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -24,6 +24,7 @@ import torch.distributed
 import numpy as np
 
 from torch import nn
+from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
@@ -41,14 +42,6 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
 
-HAS_MEGABLOCKS = True
-try:
-    import stk
-    import megablocks.ops as ops
-except ImportError:
-    logger.warning("Mixtral: megablocks is not installed")
-    HAS_MEGABLOCKS = False
-
 
 class MixtralConfig(PretrainedConfig):
     model_type = "mixtral"
@@ -229,10 +222,6 @@ class MixtralAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -292,7 +281,7 @@ class MixtralAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
@@ -321,18 +310,6 @@ def round_up(x: torch.Tensor, value: int):
 
 
 class BlockSparseMoE(nn.Module):
-    """
-    Built on the paper and library Megablocks as described in
-    https://arxiv.org/abs/2211.15841. This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
-    """
-
     def __init__(self, prefix, config: MixtralConfig, weights):
         super().__init__()
         self.hidden_dim = config.hidden_size
@@ -357,236 +334,40 @@ class BlockSparseMoE(nn.Module):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
-
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.w13 = torch.cat([w1, w3], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts", "w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
 
         self.process_group = weights.process_group
 
-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
-
-    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
-        padded_tokens, _ = x.size()
-        assert padded_tokens % self.blocking == 0
-        assert self.ffn_dim % self.blocking == 0
-
-        # Offsets for the sparse matrix. All rows have the
-        # same number of nonzero blocks dictated by the
-        # dimensionality of a single expert.
-        block_rows = padded_tokens // self.blocking
-        blocks_per_row = self.ffn_dim // self.blocking
-        if self.offsets is None or block_rows > self.offsets_block_rows:
-            self.offsets = torch.arange(
-                0,
-                block_rows * blocks_per_row + 1,
-                blocks_per_row,
-                dtype=torch.int32,
-                device=x.device,
-            )
-            self.offsets_block_rows = block_rows
-            offsets = self.offsets
-        else:
-            offsets = self.offsets[: block_rows + 1]
-
-        # Indices for the sparse matrix. The indices for
-        # the intermediate matrix are dynamic depending
-        # on the mapping of tokens to experts.
-        column_indices = ops.topology(
-            padded_bins, self.blocking, block_rows, blocks_per_row
-        )
-
-        # For now, use meta init to save the device memory.
-        data = torch.empty(
-            column_indices.numel(),
-            self.blocking,
-            self.blocking,
-            dtype=x.dtype,
-            device="meta",
-        )
-        shape = (padded_tokens, self.ffn_dim * self.num_experts)
-        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-        return stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            False,
-            False,
-            False,
-        )
-
-    def indices_and_padded_bins(self, selected_experts: torch.Tensor):
-        # Sort the expert ids to produce the scatter/gather
-        # indices for the permutation.
-        # selected_experts = selected_experts.int()
-
-        # returns bin_ids == num of experts for this sequence ? == unique selected experts?
-        # and indices == how to sort tokens?
-        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
-        # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
-        # indices => [14, 32, 33, ...] => [num_tokens * top_k]
-
-        # Histogram the expert ids to identify the number of
-        # tokens routed to each expert.
-        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
-        # tokens_per_expert => [3, 0, 2, ...] => [num_experts]
-
-        # Round the token counts up to the block size used in
-        # the matrix muliplications. Caculate the starting
-        # position of each bin.
-
-        # List of size num_experts
-        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
-        # padded_tokens_per_expert => [128, O, 128, ...]
-
-        # Cumulative selected experts per token
-        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-        padded_bins = promote_scalar(padded_bins)
-        # padded_bins => [128, 128, 256, ...]
-
-        # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-        bins = promote_scalar(bins)
-        # bins => [3, 3, 5, ...]
-
-        return indices, bin_ids, bins, padded_bins, tokens_per_expert
-
-    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        selected_experts, weights = select_experts(gate_logits, self.top_k)
-
-        (
-            indices,
-            bin_ids,
-            bins,
-            padded_bins,
-            _,
-        ) = self.indices_and_padded_bins(selected_experts)
-
-        # Permute tokens and pad to prepare expert computation
-        # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
-
-        # Create the sparse matrix topology
-        with torch.no_grad():
-            topo = self.topology(x, padded_bins)
-
-        # Perform the expert computation
-        # First Dense x Dense -> Sparse for w1 and w3,
-        # (top_k * sequence_length + padding, ffn_dim * n_experts)
-        x = stk.Matrix(
-            topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
-            * stk.ops.sdd(x, self.w3.t(), topo).data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-
-        # Then Sparse x Dense -> Dense for w2
-        # (top_k * sequence_length + padding, model_dim)
-        x = stk.ops.dsd(x, self.w2)
-
-        # Permute back and remove padding
-        # (sequence_length, model_dim)
-        x = ops.padded_scatter(
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
             x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
+            self.w13,
+            self.w2,
+            router_logits,
             self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-
-        return x.view(*input_shape)
-
-    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-
-        # Expand to [num_experts, sequence_length, model_dim]
-        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
-
-        # Permute to [num_experts, model_dim, ffn_dim]
-        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
+            renormalize=True,
+            inplace=True,
         )
-        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-
-        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
-
-        out = torch.bmm(
-            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
-        )
-        # Mask not selected experts
-        out *= weights.t().view(self.num_experts, -1, 1)
-
-        # Sum experts
-        out = out.sum(0)
 
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
 
-        return out
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
+        return out.view(*x.shape)
 
 
 class DenseMoE(nn.Module):
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index ee062d3d..ad8933a2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -120,9 +120,6 @@ class FlashNeoxAttention(torch.nn.Module):
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
-        self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
 
     def forward(
         self,
@@ -168,7 +165,7 @@ class FlashNeoxAttention(torch.nn.Module):
             qkv[:, 0],
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index cfe447a7..48f54e25 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -140,10 +140,6 @@ class FlashPhiAttention(torch.nn.Module):
             weights=weights,
             bias=True,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -206,7 +202,7 @@ class FlashPhiAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 94023b33..a8268220 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -104,10 +104,6 @@ class Qwen2Attention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -167,7 +163,7 @@ class Qwen2Attention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index a9127d1f..3ac912f4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -151,15 +151,6 @@ class FlashRWAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
 
-        if self.num_heads_kv == 1:
-            self.kv_head_mapping = torch.zeros(
-                self.num_heads, dtype=torch.int32, device=weights.device
-            )
-        else:
-            self.kv_head_mapping = torch.arange(
-                0, self.num_heads, dtype=torch.int32, device=weights.device
-            )
-
     def forward(
         self,
         hidden_states,
@@ -213,7 +204,7 @@ class FlashRWAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_heads_kv,
             self.softmax_scale,
             block_tables,
             input_lengths,
@@ -272,10 +263,6 @@ class FlashRWLargeAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
 
-        self.kv_head_mapping = torch.arange(
-            0, self.num_groups, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_heads)
-
     def forward(
         self,
         hidden_states,
@@ -332,7 +319,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_groups,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index bbb603a7..63b458b2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -241,9 +241,6 @@ class FlashMQAttention(torch.nn.Module):
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
         )
-        self.kv_head_mapping = torch.zeros(
-            self.num_heads, dtype=torch.int32, device=weights.device
-        )
 
     def forward(
         self,
@@ -292,7 +289,7 @@ class FlashMQAttention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index ed77af78..63395099 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -190,9 +190,6 @@ class Starcoder2Attention(torch.nn.Module):
             bias=config.use_bias,
         )
         self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -252,7 +249,7 @@ class Starcoder2Attention(torch.nn.Module):
             query,
             kv_cache[0],
             kv_cache[1],
-            self.kv_head_mapping,
+            self.num_key_value_heads,
             self.softmax_scale,
             block_tables,
             input_lengths,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5c25f341..57dd8704 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -690,7 +690,7 @@ class FlashCausalLM(Model):
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 2e1055b2..7b990c0a 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -378,7 +378,7 @@ class BaseFlashMistral(FlashCausalLM):
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 48f8ef70..45090c64 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -88,6 +88,9 @@ def attention(
             out,
             cu_seqlens,
             cu_seqlens,
+            None,
+            None,
+            None,
             max_s,
             max_s,
             0.0,
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 209f1c8a..ad70651f 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -174,6 +174,8 @@ class EETQLinear(nn.Module):
     ) -> None:
         super().__init__()
         device = weight.device
+        if weight.dtype != torch.float16:
+            weight = weight.to(dtype=torch.float16)
         weight = torch.t(weight).contiguous().cpu()
         weight, scale = quant_weights(weight, torch.int8, False)
 
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index 4b12744c..09e426ae 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -1,8 +1,7 @@
 import torch
 
 # vllm imports
-from vllm import cache_ops
-from vllm import attention_ops
+from vllm._C import cache_ops, ops
 
 _PARTITION_SIZE = 512
 
@@ -14,7 +13,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
 
 
 def attention(
@@ -22,7 +21,7 @@ def attention(
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
-    kv_head_mapping: torch.Tensor,
+    num_key_value_heads: int,
     softmax_scale: float,
     block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
@@ -54,20 +53,22 @@ def attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        attention_ops.paged_attention_v1(
+        ops.paged_attention_v1(
             out,
             query,
             key_cache,
             value_cache,
-            kv_head_mapping,
+            num_key_value_heads,
             softmax_scale,
             block_tables,
             input_lengths,
             block_size,
             max_s,
             None,
+            "auto",
+            1.0,
         )
     else:
         # Run PagedAttention V2.
@@ -83,7 +84,7 @@ def attention(
             device=out.device,
         )
         max_logits = torch.empty_like(exp_sums)
-        attention_ops.paged_attention_v2(
+        ops.paged_attention_v2(
             out,
             exp_sums,
             max_logits,
@@ -91,11 +92,13 @@ def attention(
             query,
             key_cache,
             value_cache,
-            kv_head_mapping,
+            num_key_value_heads,
             softmax_scale,
             block_tables,
             input_lengths,
             block_size,
             max_s,
             None,
+            "auto",
+            1.0,
         )
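
Reviewer note (not part of the patch): the sketch below restates the fused_moe call contract that replaces the megablocks path in flash_dbrx_modeling.py and flash_mixtral_modeling.py, with tensor shapes inferred from the weight reshapes above. It assumes vllm is installed at the commit pinned in server/Makefile-vllm; the helper name moe_forward and its argument names are illustrative only.

# Illustrative sketch only -- mirrors the new BlockSparseMoE.forward in this diff.
import torch
from vllm.model_executor.layers.fused_moe import fused_moe


def moe_forward(hidden_states, gate, w13, w2, top_k, process_group=None):
    # hidden_states: (num_tokens, hidden_dim)
    # w13: (num_experts, 2 * ffn_dim, hidden_dim) -- w1/w3 (or w1/v1) stacked on dim 1
    # w2:  (num_experts, hidden_dim, ffn_dim)     -- down projection, pre-transposed
    router_logits = gate(hidden_states)  # (num_tokens, num_experts)
    out = fused_moe(
        hidden_states,
        w13,
        w2,
        router_logits,
        top_k,
        renormalize=True,
        inplace=True,
    )
    # Each tensor-parallel rank holds a shard of ffn_dim, so partial results
    # are summed across the process group, as in the patched modules.
    if process_group is not None and process_group.size() > 1:
        torch.distributed.all_reduce(out, group=process_group)
    return out.view(*hidden_states.shape)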