mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

feat(models): Add DBRX

This commit is contained in:
parent 2c83d09d3b
commit dcfefc425a
@@ -35,6 +35,7 @@ from text_generation_server.utils.layers import (
    SpeculativeHead,
    get_linear,
)
+from text_generation_server.utils.log import log_once

HAS_MEGABLOCKS = True
try:
@@ -176,33 +177,103 @@ def _load_gqa(config, prefix: str, weights):
    assert config.d_model % config.n_heads == 0
    assert config.n_heads % weights.process_group.size() == 0

-   weight = weights.get_weights_col_packed_qkv(
-       prefix=f"{prefix}.Wqkv",
-       quantize=config.quantize,
    head_dim = config.d_model // config.n_heads
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    q_block_size = config.d_model // world_size
    q_start = rank * q_block_size
    q_stop = (rank + 1) * q_block_size

    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
    k_offset = config.d_model
    k_start = k_offset + rank * kv_block_size
    k_stop = k_offset + (rank + 1) * kv_block_size

    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
    v_start = v_offset + rank * kv_block_size
    v_stop = v_offset + (rank + 1) * kv_block_size

    if config.quantize in ["gptq", "awq"]:
        try:
            qweight_slice = weights._get_slice(f"{prefix}.qweight")
            q_qweight = qweight_slice[:, q_start:q_stop]
            k_qweight = qweight_slice[:, k_start:k_stop]
            v_qweight = qweight_slice[:, v_start:v_stop]

            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
            )

-   if config.quantize not in ["gptq", "awq"]:
        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
        q_qzeros = qzeros_slice[:, q_start:q_stop]
        k_qzeros = qzeros_slice[:, k_start:k_stop]
        v_qzeros = qzeros_slice[:, v_start:v_stop]

        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)

        scales_slice = weights._get_slice(f"{prefix}.scales")
        q_scales = scales_slice[:, q_start:q_stop]
        k_scales = scales_slice[:, k_start:k_stop]
        v_scales = scales_slice[:, v_start:v_stop]

        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)

        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()

        from text_generation_server.utils.layers import HAS_EXLLAMA

        use_exllama = (
            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
        )

        if config.quantize == "gptq" and quant_method == "gptq":
            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
            q_g_idx = g_idx_slice[:, q_start:q_stop]
            k_g_idx = g_idx_slice[:, k_start:k_stop]
            v_g_idx = g_idx_slice[:, v_start:v_stop]

            w = [q_g_idx, k_g_idx, v_g_idx]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]
        elif config.quantize == "gptq" and quant_method == "awq":
            log_once(
                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
            )
            from text_generation_server.utils.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
            if use_exllama:
                g_idx = None
            else:
                g_idx = (
                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
                    // groupsize
                ).to(dtype=torch.int32)
        else:
            g_idx = None

        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
    else:
        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
        q = qkv_slice[q_start:q_stop]
        k = qkv_slice[k_start:k_stop]
        v = qkv_slice[v_start:v_stop]

        weight = torch.cat([q, k, v], dim=0)
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

    head_size = config.d_model // config.n_heads
    num_heads = config.n_heads // weights.process_group.size()
    num_key_value_heads = (
        config.attn_config.kv_n_heads // weights.process_group.size()
    )
    assert list(weight.shape) == [
        (num_heads + 2 * num_key_value_heads) * head_size,
        config.d_model,
    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.attn_config.kv_n_heads) * head_size, config.d_model]}"

    return TensorParallelColumnLinear(
        get_linear(weight, bias=None, quantize=config.quantize)
    )
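For reference, a minimal standalone sketch of the Wqkv offset arithmetic introduced above, with made-up sizes (d_model=8, n_heads=4, kv_n_heads=2, two shards); it mirrors the unquantized branch only and is not code from this diff:

import torch

# Toy sizes, chosen for illustration only.
d_model, n_heads, kv_n_heads, world_size, rank = 8, 4, 2, 2, 1
head_dim = d_model // n_heads  # 2

# Packed Wqkv rows: [Q: d_model | K: kv_n_heads * head_dim | V: kv_n_heads * head_dim]
total_rows = d_model + 2 * kv_n_heads * head_dim
wqkv = torch.arange(total_rows * d_model, dtype=torch.float32).view(total_rows, d_model)

q_block = d_model // world_size                   # Q rows owned by this rank
kv_block = (kv_n_heads * head_dim) // world_size  # K (and V) rows owned by this rank

k_offset = d_model
v_offset = d_model + kv_n_heads * head_dim

q = wqkv[rank * q_block : (rank + 1) * q_block]
k = wqkv[k_offset + rank * kv_block : k_offset + (rank + 1) * kv_block]
v = wqkv[v_offset + rank * kv_block : v_offset + (rank + 1) * kv_block]

weight = torch.cat([q, k, v], dim=0)  # this rank's shard of the fused QKV projection
print(weight.shape)  # torch.Size([8, 8])

Each rank thus ends up with (d_model + 2 * kv_n_heads * head_dim) / world_size rows of the fused projection.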

def _load_experts(config, prefix, weights):
    if config.quantize is not None:
        raise NotImplementedError("Dbrx does not support weight quantization yet.")

    world_size = weights.process_group.size()
    rank = weights.process_group.rank()
@@ -221,9 +292,9 @@ def _load_experts(config, prefix, weights):
        device=weights.device,
    )

-   slice_ = weights._get_slice(f"{prefix}.weight")
+   slice_ = weights._get_slice(f"{prefix}")

-   for i in range(config.num_local_experts):
+   for i in range(config.ffn_config.moe_num_experts):
        offset = i * expert_size
        expert_slice = slice_[start + offset : stop + offset]

@@ -233,6 +304,46 @@ def _load_experts(config, prefix, weights):
    return tensor


def _load_experts_quantized(config, prefix, weights, cls):
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    assert (
        config.ffn_config.ffn_hidden_size % world_size == 0
    ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"

    expert_size = config.ffn_config.ffn_hidden_size
    block_size = expert_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size

    slice_ = weights._get_slice(f"{prefix}")

    experts = []
    for i in range(config.ffn_config.moe_num_experts):
        if config.quantize in ["gptq", "awq"]:
            raise NotImplementedError(
                "Dbrx does not support gptq/awq quantization yet."
            )
        else:
            offset = i * expert_size
            expert_slice = (
                slice_[start + offset : stop + offset]
                .to(dtype=weights.dtype)
                .to(device=weights.device)
            )

            if cls == TensorParallelRowLinear:
                expert_slice = expert_slice.t().contiguous()
                linear = get_linear(expert_slice, None, config.quantize)
                experts.append(cls(linear, weights.process_group))
            else:
                linear = get_linear(expert_slice, None, config.quantize)
                experts.append(cls(linear))

    return experts
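A standalone illustration of the slicing used by both expert loaders above: all experts are stored stacked along dim 0, so expert i's shard for this rank is rows [i * expert_size + start, i * expert_size + stop) of the flat tensor. Toy sizes only, not code from this diff:

import torch

# Toy sizes, chosen for illustration only.
ffn_hidden_size, d_model, num_experts = 8, 4, 3
world_size, rank = 2, 1

# All experts stacked along dim 0, as in the checkpoint layout read above.
stacked = torch.randn(num_experts * ffn_hidden_size, d_model)

block_size = ffn_hidden_size // world_size      # rows per rank, per expert
start, stop = rank * block_size, (rank + 1) * block_size

shards = []
for i in range(num_experts):
    offset = i * ffn_hidden_size                # jump to expert i's block
    shards.append(stacked[start + offset : stop + offset])

print([tuple(s.shape) for s in shards])  # [(4, 4), (4, 4), (4, 4)]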

class DbrxAttention(torch.nn.Module):
    def __init__(
        self,
@@ -391,9 +502,7 @@ class DbrxNormAttentionNorm(nn.Module):
        )

        # faster post attention rms norm
-       normed_attn_res_output, attn_res = self.post_attention_layernorm(
-           attn_output, res
-       )
+       normed_attn_res_output, attn_res = self.norm_2(attn_output, res)

        return normed_attn_res_output, attn_res

@@ -663,6 +772,7 @@ class BlockSparseMoE(nn.Module):
            weights = weights / torch.norm(
                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
            )
+       weights = weights.to(x.dtype)

        # Expand to [num_experts, sequence_length, model_dim]
        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -703,8 +813,6 @@ class DenseMoE(nn.Module):
    def __init__(self, prefix, config: DbrxConfig, weights):
        super().__init__()

-       raise NotImplementedError("Quantization is not implemented for Dbrx")
-
        self.moe_normalize_expert_weights = (
            config.ffn_config.moe_normalize_expert_weights
        )
@@ -731,24 +839,24 @@ class DenseMoE(nn.Module):
            config, f"{prefix}.router.layer", weights, bias=False
        )

-       self.w1 = [
-           TensorParallelColumnLinear.load(
-               config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-           )
-           for i in range(self.num_experts)
-       ]
+       self.w1 = _load_experts_quantized(
+           config,
+           prefix=f"{prefix}.experts.mlp.w1",
+           weights=weights,
+           cls=TensorParallelColumnLinear,
+       )
-       self.w3 = [
-           TensorParallelColumnLinear.load(
-               config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-           )
-           for i in range(self.num_experts)
-       ]
+       self.w2 = _load_experts_quantized(
+           config,
+           prefix=f"{prefix}.experts.mlp.w2",
+           weights=weights,
+           cls=TensorParallelRowLinear,
+       )
-       self.w2 = [
-           TensorParallelRowLinear.load(
-               config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-           )
-           for i in range(self.num_experts)
-       ]
+       self.v1 = _load_experts_quantized(
+           config,
+           prefix=f"{prefix}.experts.mlp.v1",
+           weights=weights,
+           cls=TensorParallelColumnLinear,
+       )

        self.process_group = weights.process_group

@@ -764,26 +872,30 @@ class DenseMoE(nn.Module):
        # gate_logits: (sequence_length, n_experts)
        gate_logits = self.gate(x)
        # all_probs: (sequence_length, n_experts) and upcast for softmax
-       all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+       weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

        if self.top_k < self.num_experts:
            _, not_selected_experts = torch.topk(
-               all_probs,
+               weights,
                self.num_experts - self.top_k,
                largest=False,
                sorted=False,
                dim=1,
            )
            # Mask not selected experts
-           all_probs.scatter_(1, not_selected_experts, 0)
+           weights.scatter_(1, not_selected_experts, 0)

        # Re-normalize
-       weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+       if self.moe_normalize_expert_weights:
+           weights = weights / torch.norm(
+               weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+           )
+       weights = weights.to(x.dtype)

        # Final output tensor
        out = x.new_zeros(x.shape[0], self.hidden_dim)
        for i in range(self.num_experts):
-           h = self.act(self.w1[i](x)) * self.w3[i](x)
+           h = self.act(self.w1[i](x)) * self.v1[i](x)
            h = self.w2[i](h, reduce=False)
            # Add expert output to out with masking
            out += h * weights[:, i].view(-1, 1)
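The routing math above can be exercised on its own. A minimal sketch with toy sizes (not code from this diff) of the float32 softmax, top-k masking via scatter_, and the optional p-norm renormalization that replaces the old sum-based renormalization:

import torch

# Toy sizes, chosen for illustration only.
num_experts, top_k, moe_normalize_expert_weights = 4, 2, 1.0
gate_logits = torch.randn(3, num_experts, dtype=torch.float16)  # (seq_len, n_experts)

# Upcast to float32 for a stable softmax.
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

if top_k < num_experts:
    # Find the (num_experts - top_k) smallest probabilities per row and zero them.
    _, not_selected = torch.topk(
        weights, num_experts - top_k, largest=False, sorted=False, dim=1
    )
    weights.scatter_(1, not_selected, 0)

if moe_normalize_expert_weights:
    weights = weights / torch.norm(
        weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True
    )
weights = weights.to(gate_logits.dtype)

print(weights)  # two non-zero entries per row; each row sums to ~1 for p=1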
@@ -821,7 +933,7 @@ class DbrxLayer(nn.Module):
        max_s,
    ):
        # Self Attention
-       attn_output, attn_res = self.self_attn(
+       attn_output, attn_res = self.attn(
            hidden_states,
            residual,
            cos,
@@ -861,9 +973,9 @@ class DbrxModel(torch.nn.Module):
            prefix="transformer.norm_f", weights=weights, eps=1e-5
        )

-       self.head_size = self.layers[0].self_attn.head_size
-       self.num_heads = self.layers[0].self_attn.num_heads
-       self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+       self.head_size = self.layers[0].attn.self_attn.head_size
+       self.num_heads = self.layers[0].attn.self_attn.num_heads
+       self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads

    def forward(
        self,
@@ -880,7 +992,7 @@ class DbrxModel(torch.nn.Module):

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
-       cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+       cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

@@ -552,6 +552,7 @@ class BlockSparseMoE(nn.Module):

        # Re-normalize
        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+       weights = weights.to(x.dtype)

        # Expand to [num_experts, sequence_length, model_dim]
        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -660,6 +661,7 @@ class DenseMoE(nn.Module):

        # Re-normalize
        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+       weights = weights.to(x.dtype)

        # Final output tensor
        out = x.new_zeros(x.shape[0], self.hidden_dim)
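These two hunks only add a cast of the routing weights back to the activation dtype after the float32 softmax, presumably to avoid silent promotion when the weights later multiply half-precision expert outputs. A small check of that PyTorch behavior, illustrative only and not taken from this diff:

import torch

x = torch.randn(2, 4, dtype=torch.float16)                       # expert output dtype
weights = torch.softmax(torch.randn(2, 3), dim=1, dtype=torch.float)  # float32 routing weights

promoted = x * weights[:, 0].view(-1, 1)
print(promoted.dtype)  # torch.float32 -> result silently promoted

weights = weights.to(x.dtype)
kept = x * weights[:, 0].view(-1, 1)
print(kept.dtype)      # torch.float16 -> stays in the compute dtype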
@@ -3,6 +3,7 @@ import torch.distributed

from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
+from transformers.models.gpt2 import GPT2TokenizerFast

from text_generation_server.models import FlashCausalLM
@@ -36,6 +37,17 @@ class FlashDbrx(FlashCausalLM):
        else:
            raise NotImplementedError("FlashDBRX is only available on GPU")

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
                use_fast=True,
                from_slow=False,
            )
        except:
            # FIXME: change back to model id once the tokenizer.json is merged
            tokenizer = GPT2TokenizerFast.from_pretrained(
                "Xenova/dbrx-instruct-tokenizer",