This commit is contained in:
OlivierDehaene 2023-12-10 10:38:23 +01:00
parent 9ecfa16b12
commit af1989459c
11 changed files with 863 additions and 217 deletions

View File

@@ -212,6 +212,8 @@ RUN cd server && \
pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir
+RUN pip install git+https://github.com/OlivierDehaene/megablocks#33fad2b0eae7c47b8fedfb3ad415af8169386918
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
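
The Mixtral modeling file added later in this commit imports megablocks.ops and stk at import time, so it can be worth sanity-checking the pinned megablocks install inside the image. A hedged sketch; it assumes stk is pulled in as a dependency of the pinned fork, which is not verified here:

    # Hedged import check for the MoE kernel dependencies; run with the image's python.
    import importlib

    for module in ("megablocks.ops", "stk"):
        importlib.import_module(module)
    print("megablocks and stk import OK")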

View File

@@ -188,7 +188,10 @@ def download_weights(
# Try to see if there are local pytorch weights
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
-local_pt_files = utils.weight_files(model_id, revision, ".bin")
+try:
+local_pt_files = utils.weight_files(model_id, revision, extension=".bin")
+except (FileNotFoundError, utils.EntryNotFoundError):
+local_pt_files = utils.weight_files(model_id, revision, extension=".pt")
# No local pytorch weights
except utils.LocalEntryNotFoundError:
@@ -199,7 +202,10 @@ def download_weights(
)
# Try to see if there are pytorch weights on the hub
-pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
+try:
+pt_filenames = utils.weight_hub_files(model_id, revision, extension=".bin")
+except utils.EntryNotFoundError:
+pt_filenames = utils.weight_hub_files(model_id, revision, extension=".pt")
# Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
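
The hunks above make the PyTorch weight lookup fall back from .bin shards to .pt checkpoints when the former are missing. A minimal standalone sketch of the same pattern; the wrapper name below is hypothetical, while utils.weight_files and the exception classes are taken from the diff rather than verified against the repository:

    # Hedged sketch of the ".bin" -> ".pt" fallback introduced above.
    from text_generation_server import utils

    def local_pytorch_files(model_id: str, revision: str):
        try:
            return utils.weight_files(model_id, revision, extension=".bin")
        except (FileNotFoundError, utils.EntryNotFoundError):
            # No .bin shards found; fall back to .pt checkpoints.
            return utils.weight_files(model_id, revision, extension=".pt")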

View File

@@ -1,4 +1,3 @@
-import os
import torch
from loguru import logger
@@ -287,6 +286,7 @@ def get_model(
if MISTRAL:
return FlashMistral(
model_id,
+config_dict.get("architectures", []),
revision,
quantize=quantize,
dtype=dtype,
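
get_model now forwards the checkpoint's architectures list (read from config.json into config_dict) so that FlashMistral can route Mixtral checkpoints to the new modeling code added in this commit. A hedged sketch of that dispatch as a standalone helper; the function name is hypothetical, the imported class names come from the diff:

    # Hedged sketch of the architecture-based dispatch; not the repo's exact code.
    import json

    def select_flash_mistral_classes(config_path: str):
        with open(config_path) as f:
            architectures = json.load(f).get("architectures", [])
        if "MixtralForCausalLM" in architectures:
            from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
                MixtralConfig,
                FlashMixtralForCausalLM,
            )
            return MixtralConfig, FlashMixtralForCausalLM
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            MistralConfig,
            FlashMistralForCausalLM,
        )
        return MistralConfig, FlashMistralForCausalLM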

View File

@@ -34,14 +34,8 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
+FastRMSNorm
)
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
-if IS_CUDA_SYSTEM:
-import dropout_layer_norm
-elif IS_ROCM_SYSTEM:
-from vllm import layernorm_ops
class LlamaConfig(PretrainedConfig):
def __init__(
@@ -95,75 +89,6 @@ class LlamaConfig(PretrainedConfig):
)
class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -363,10 +288,8 @@ class FlashLlamaLayer(nn.Module):
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
-self.input_layernorm = LlamaRMSNorm(
-prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
-)
-self.post_attention_layernorm = LlamaRMSNorm(
+self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
+self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -430,7 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
for layer_id in range(config.num_hidden_layers)
]
)
-self.norm = LlamaRMSNorm(
+self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
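
Both Llama and Mistral now delegate to the shared FastRMSNorm added to utils/layers.py later in this commit. As a reference for what the fused CUDA/ROCm kernels compute, here is a hedged plain-PyTorch sketch of the residual-add plus RMSNorm step, written for illustration rather than copied from the repository:

    # Hedged eager-mode reference for the fused RMSNorm path.
    import torch

    def rms_norm_ref(hidden_states, weight, residual=None, eps=1e-6):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        x = hidden_states.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        # Cast back to the weight dtype (fp16/bf16) before scaling, as the modeling code does.
        return weight * x.to(weight.dtype), residual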

View File

@@ -35,13 +35,9 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
+FastRMSNorm
)
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
-if IS_CUDA_SYSTEM:
-import dropout_layer_norm
-elif IS_ROCM_SYSTEM:
-from vllm import layernorm_ops
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2")
@@ -100,76 +96,6 @@ class MistralConfig(PretrainedConfig):
**kwargs,
)
class MistralRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -371,10 +297,10 @@ class MistralLayer(nn.Module):
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
-self.input_layernorm = MistralRMSNorm(
+self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
-self.post_attention_layernorm = MistralRMSNorm(
+self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -440,7 +366,7 @@ class MistralModel(torch.nn.Module):
for layer_id in range(config.num_hidden_layers)
]
)
-self.norm = MistralRMSNorm(
+self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)

View File

@@ -0,0 +1,688 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
import numpy as np
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
from text_generation_server.utils.layers import (
FastLinear,
FastRMSNorm,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
)
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mixtral model requires flash attn v2")
try:
import megablocks.ops as ops
except ImportError:
raise ImportError("Mixtral model requires megablocks to be installed")
try:
import stk
except ImportError:
raise ImportError("Mixtral model requires stk to be installed")
class MixtralConfig(PretrainedConfig):
model_type = "mistral"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
num_experts_per_tok=2,
num_local_experts=8,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
def _load_experts(config, prefix, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
slice_ = weights._get_slice(prefix)
world_size = weights.process_group.size()
rank = weights.process_group.rank()
if world_size == 1:
tensor = slice_[:].to(dtype=weights.dtype).to(device=weights.device)
else:
assert (
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
assert slice_.get_shape()[0] == config.num_local_experts * config.intermediate_size
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
expert_slices = []
for i in range(config.num_local_experts):
expert_start = i * config.intermediate_size
expert_slices.append(slice_[start + expert_start:stop + expert_start])
tensor = torch.cat(expert_slices, dim=0).to(dtype=weights.dtype).to(device=weights.device)
return tensor
class MixtralAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else 0
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size ** -0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.wo = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.wo",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.wo(attn_output.view(-1, self.num_heads * self.head_size))
class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accommodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drops tokens at the cost of reduced performance or (2) sets
the capacity factor to the number of experts and thus wastes
computation and memory on padding.
"""
def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.w1", weights)
self.w2 = _load_experts(config, f"{prefix}.w2", weights)
self.w3 = _load_experts(config, f"{prefix}.w3", weights)
self.process_group = weights.process_group
# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
blocks_per_row)
# TODO(tgale): This is unused. Remove the need for this in stk.
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
# Round the token counts up to the block size used in
# the matrix multiplications. Calculate the starting
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = ops.round_up(tokens_per_expert,
self.blocking)
# padded_tokens_per_expert => [128, 0, 128, ...]
# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]
return indices, bin_ids, bins, padded_bins, tokens_per_expert
@torch.inference_mode()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.flatten().to(x.dtype)
selected_experts = selected_experts.flatten()
(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
self.top_k, x.shape[0] + self.num_experts * self.blocking)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data) *
stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
# Permute back and remove padding
# (sequence_length, model_dim)
x = ops.padded_scatter(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.top_k,
self.quantize_scatter_num_bits,
).view(*input_shape)
if self.process_group.size() > 1:
torch.distributed.all_reduce(x, group=self.process_group)
return x.view(*input_shape)
class MixtralLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"layers.{layer_id}"
self.attention = MixtralAttention(
prefix=f"{prefix}.attention", config=config, weights=weights
)
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
self.attention_norm = FastRMSNorm.load(
prefix=f"{prefix}.attention_norm", weights=weights, eps=config.rms_norm_eps
)
self.ffn_norm = FastRMSNorm.load(
prefix=f"{prefix}.ffn_norm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.attention_norm(hidden_states, residual)
# Self Attention
attn_output = self.attention(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
# faster post attention rms norm
normed_attn_res_output, attn_res = self.ffn_norm(
attn_output, res
)
mlp_output = self.block_sparse_moe(normed_attn_res_output)
return mlp_output, attn_res
class MixtralModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.tok_embeddings = TensorParallelEmbedding(
prefix="tok_embeddings", weights=weights
)
self.layers = nn.ModuleList(
[
MixtralLayer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix="norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
self.num_key_value_heads = self.layers[0].attention.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.tok_embeddings(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = MixtralModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="output",
weights=weights,
)
self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
else:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
input_lengths = torch.clamp(input_lengths, max=self.max_past)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits
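
BlockSparseMoE above leans on the megablocks/stk kernels for the expert computation. For sanity-checking routing behaviour on small inputs, here is a hedged dense PyTorch reference of the same top-k expert mixing (softmax gate, renormalized top-k weights, SiLU-gated experts); it is an illustration only and makes no claim to match the block-sparse kernels bit-for-bit:

    # Hedged dense reference for top-k MoE routing; w1/w2/w3 hold one
    # (ffn_dim, hidden_dim)-shaped weight per expert, i.e. the merged expert
    # tensors loaded above split along their first dimension. Illustration only.
    import torch

    def dense_moe_ref(x, gate, w1, w2, w3, top_k=2):
        # x: (tokens, hidden); w1/w3/w2: (experts, ffn_dim, hidden)
        probs = torch.softmax(gate(x), dim=-1, dtype=torch.float)
        weights, experts = torch.topk(probs, top_k, dim=-1)
        weights = (weights / weights.sum(dim=-1, keepdim=True)).to(x.dtype)
        out = torch.zeros_like(x)
        for e in range(w1.shape[0]):
            token_idx, k_idx = torch.where(experts == e)
            if token_idx.numel() == 0:
                continue
            h = torch.nn.functional.silu(x[token_idx] @ w1[e].t()) * (x[token_idx] @ w3[e].t())
            out.index_add_(0, token_idx, weights[token_idx, k_idx, None] * (h @ w2[e]))
        return out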

View File

@@ -6,7 +6,6 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,

View File

@@ -8,14 +8,13 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
-from typing import Optional, Tuple, Type
+from typing import Optional, Tuple, Type, List
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
get_cache_manager,
-set_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
@@ -282,6 +281,7 @@ class FlashMistral(FlashCausalLM):
def __init__(
self,
model_id: str,
+architectures: List[str],
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
@@ -305,7 +305,15 @@ class FlashMistral(FlashCausalLM):
trust_remote_code=trust_remote_code,
)
-config = MistralConfig.from_pretrained(
+if "MixtralForCausalLM" in architectures:
+from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
+config_cls = MixtralConfig
+model_cls = FlashMixtralForCausalLM
+else:
+config_cls = MistralConfig
+model_cls = FlashMistralForCausalLM
+config = config_cls.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
@@ -321,7 +329,7 @@ class FlashMistral(FlashCausalLM):
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id)
-model = FlashMistralForCausalLM(config, weights)
+model = model_cls(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashMistral, self).__init__(

View File

@@ -98,7 +98,10 @@ def weight_files(
if extension != ".safetensors":
raise e
# Try to see if there are pytorch weights
+try:
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
+except EntryNotFoundError:
+pt_filenames = weight_hub_files(model_id, revision, extension=".pt")
# Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them

View File

@@ -47,12 +47,14 @@ elif CAN_EXLLAMA:
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "1"
except ImportError:
@@ -526,9 +528,12 @@ class TensorParallelEmbedding(nn.Module):
try:
if IS_CUDA_SYSTEM:
import dropout_layer_norm
+elif IS_ROCM_SYSTEM:
+from vllm import layernorm_ops
else:
dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
@@ -563,10 +568,81 @@ try:
residual = hidden_states
return normed_hidden_states, residual
class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float):
super().__init__()
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
@classmethod
def load(cls, prefix, weights, eps=1e-6):
weight = weights.get_tensor(f"{prefix}.weight")
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
except ImportError:
pass
try:
if IS_CUDA_SYSTEM:
from flash_attn.layers.rotary import RotaryEmbedding
@@ -574,12 +650,14 @@ try:
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {
@@ -589,6 +667,7 @@ try:
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
@@ -606,12 +685,12 @@ try:
if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
-q2 = query[..., rotary_dim : 2 * rotary_dim]
+q2 = query[..., rotary_dim: 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim]
-k2 = key[..., rotary_dim : 2 * rotary_dim]
+k2 = key[..., rotary_dim: 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
@@ -630,7 +709,8 @@ try:
True
)
else:
-raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
+raise ValueError(
+"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
@classmethod
def static(cls, config, dim, base, device):
@@ -747,6 +827,7 @@ try:
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
@@ -783,8 +864,11 @@ try:
# Inverse dim formula to find dim based on number of rotations
import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
-return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
+return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
# Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
@@ -792,7 +876,8 @@ try:
low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(
high_rot, dim, base, max_position_embeddings))
-return max(low, 0), min(high, dim-1) # Clamp values just in case
+return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim):
if min == max:
@@ -802,13 +887,16 @@ try:
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_mscale(scale=1):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
-def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow):
+def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor,
+attn_factor, beta_fast, beta_slow):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
@@ -818,7 +906,8 @@ try:
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
-self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
+self.mscale = float(get_mscale(
+self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@@ -834,13 +923,15 @@ try:
)
freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
-low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
-inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
+low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base,
+self.max_position_embeddings)
+inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(
+device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.inv_freq = inv_freq
-self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
+self.mscale = float(get_mscale(
+self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
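
For intuition about the YaRN helpers reformatted above, the correction range maps rotation counts onto RoPE dimension indices. A hedged, self-contained recalculation with illustrative parameters (dim=128, base=10000, original context 2048, and the paper's beta_fast=32 / beta_slow=1 defaults); the functions mirror the formulas shown above rather than being imported from the repository:

    # Hedged re-derivation of the YaRN correction range for example parameters.
    import math

    def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
        return max(low, 0), min(high, dim - 1)  # clamp just in case

    print(find_correction_range(32, 1, dim=128))  # RoPE dims separating extrapolation from interpolation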

View File

@@ -23,7 +23,7 @@ class Weights:
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
-raise RuntimeError(
+logger.warning(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
@@ -116,7 +116,7 @@ class Weights:
size = slice_.get_shape()[dim]
assert (
size % world_size == 0
-), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str):