feat(models): Add DBRX
Commit dcfefc425a (parent 2c83d09d3b)
@@ -35,6 +35,7 @@ from text_generation_server.utils.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.utils.log import log_once
 
 HAS_MEGABLOCKS = True
 try:
@@ -176,33 +177,103 @@ def _load_gqa(config, prefix: str, weights):
     assert config.d_model % config.n_heads == 0
     assert config.n_heads % weights.process_group.size() == 0
 
-    weight = weights.get_weights_col_packed_qkv(
-        prefix=f"{prefix}.Wqkv",
-        quantize=config.quantize,
-    )
-
-    if config.quantize not in ["gptq", "awq"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-        head_size = config.d_model // config.n_heads
-        num_heads = config.n_heads // weights.process_group.size()
-        num_key_value_heads = (
-            config.attn_config.kv_n_heads // weights.process_group.size()
-        )
-        assert list(weight.shape) == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.d_model,
-        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.attn_config.kv_n_heads) * head_size, config.d_model]}"
+    head_dim = config.d_model // config.n_heads
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
+
+    q_block_size = config.d_model // world_size
+    q_start = rank * q_block_size
+    q_stop = (rank + 1) * q_block_size
+
+    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
+    k_offset = config.d_model
+    k_start = k_offset + rank * kv_block_size
+    k_stop = k_offset + (rank + 1) * kv_block_size
+
+    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
+    v_start = v_offset + rank * kv_block_size
+    v_stop = v_offset + (rank + 1) * kv_block_size
+
+    if config.quantize in ["gptq", "awq"]:
+        try:
+            qweight_slice = weights._get_slice(f"{prefix}.qweight")
+            q_qweight = qweight_slice[:, q_start:q_stop]
+            k_qweight = qweight_slice[:, k_start:k_stop]
+            v_qweight = qweight_slice[:, v_start:v_stop]
+
+            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
+            )
+
+        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
+        q_qzeros = qzeros_slice[:, q_start:q_stop]
+        k_qzeros = qzeros_slice[:, k_start:k_stop]
+        v_qzeros = qzeros_slice[:, v_start:v_stop]
+
+        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
+
+        scales_slice = weights._get_slice(f"{prefix}.scales")
+        q_scales = scales_slice[:, q_start:q_stop]
+        k_scales = scales_slice[:, k_start:k_stop]
+        v_scales = scales_slice[:, v_start:v_stop]
+
+        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
+
+        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
+
+        from text_generation_server.utils.layers import HAS_EXLLAMA
+
+        use_exllama = (
+            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
+        )
+
+        if config.quantize == "gptq" and quant_method == "gptq":
+            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
+            q_g_idx = g_idx_slice[:, q_start:q_stop]
+            k_g_idx = g_idx_slice[:, k_start:k_stop]
+            v_g_idx = g_idx_slice[:, v_start:v_stop]
+
+            w = [q_g_idx, k_g_idx, v_g_idx]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif config.quantize == "gptq" and quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.utils.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+    else:
+        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
+        q = qkv_slice[q_start:q_stop]
+        k = qkv_slice[k_start:k_stop]
+        v = qkv_slice[v_start:v_stop]
+
+        weight = torch.cat([q, k, v], dim=0)
+        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     return TensorParallelColumnLinear(
         get_linear(weight, bias=None, quantize=config.quantize)
     )
 
 
 def _load_experts(config, prefix, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Dbrx does not support weight quantization yet.")
-
     world_size = weights.process_group.size()
     rank = weights.process_group.rank()
 
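Note (not part of the diff): `_load_gqa` now slices the fused `Wqkv` tensor by hand instead of going through `get_weights_col_packed_qkv`, because DBRX stores all query rows first, then all key rows, then all value rows. The sketch below works through that offset arithmetic for every rank; the dimensions are illustrative DBRX-like values assumed for the example, not values read from a real config.

# Standalone sketch: row ranges each tensor-parallel rank takes from a fused
# Wqkv tensor of shape [d_model + 2 * kv_n_heads * head_dim, d_model].
d_model = 6144      # assumed example value
n_heads = 48        # assumed example value
kv_n_heads = 8      # assumed example value
world_size = 8      # number of shards

head_dim = d_model // n_heads                           # 128
q_block_size = d_model // world_size                    # query rows per rank
kv_block_size = (kv_n_heads * head_dim) // world_size   # key/value rows per rank
k_offset = d_model                                      # keys start after all query rows
v_offset = d_model + kv_n_heads * head_dim              # values start after all key rows

for rank in range(world_size):
    q = (rank * q_block_size, (rank + 1) * q_block_size)
    k = (k_offset + rank * kv_block_size, k_offset + (rank + 1) * kv_block_size)
    v = (v_offset + rank * kv_block_size, v_offset + (rank + 1) * kv_block_size)
    print(f"rank {rank}: q rows {q}, k rows {k}, v rows {v}")

Each rank then concatenates its three slices, which is what the non-quantized branch above does with `torch.cat([q, k, v], dim=0)`.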
@@ -221,9 +292,9 @@ def _load_experts(config, prefix, weights):
         device=weights.device,
     )
 
-    slice_ = weights._get_slice(f"{prefix}.weight")
+    slice_ = weights._get_slice(f"{prefix}")
 
-    for i in range(config.num_local_experts):
+    for i in range(config.ffn_config.moe_num_experts):
         offset = i * expert_size
         expert_slice = slice_[start + offset : stop + offset]
 
@@ -233,6 +304,46 @@ def _load_experts(config, prefix, weights):
     return tensor
 
 
+def _load_experts_quantized(config, prefix, weights, cls):
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
+
+    assert (
+        config.ffn_config.ffn_hidden_size % world_size == 0
+    ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+    expert_size = config.ffn_config.ffn_hidden_size
+    block_size = expert_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size
+
+    slice_ = weights._get_slice(f"{prefix}")
+
+    experts = []
+    for i in range(config.ffn_config.moe_num_experts):
+        if config.quantize in ["gptq", "awq"]:
+            raise NotImplementedError(
+                "Dbrx does not support gptq/awq quantization yet."
+            )
+        else:
+            offset = i * expert_size
+            expert_slice = (
+                slice_[start + offset : stop + offset]
+                .to(dtype=weights.dtype)
+                .to(device=weights.device)
+            )
+
+        if cls == TensorParallelRowLinear:
+            expert_slice = expert_slice.t().contiguous()
+            linear = get_linear(expert_slice, None, config.quantize)
+            experts.append(cls(linear, weights.process_group))
+        else:
+            linear = get_linear(expert_slice, None, config.quantize)
+            experts.append(cls(linear))
+
+    return experts
+
+
 class DbrxAttention(torch.nn.Module):
     def __init__(
         self,
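Note (not part of the diff): `_load_experts_quantized` assumes the expert weights are stored fused along dim 0, one expert after another, so every rank takes the same `block_size` rows out of each expert's block. A minimal sketch of that slicing, with made-up sizes chosen only for readability:

# Standalone sketch: per-rank slicing of a fused expert weight laid out as
# [moe_num_experts * ffn_hidden_size, d_model].
import torch

moe_num_experts, ffn_hidden_size, d_model = 4, 8, 6   # assumed toy sizes
world_size, rank = 2, 1

fused = torch.arange(moe_num_experts * ffn_hidden_size * d_model, dtype=torch.float32)
fused = fused.view(moe_num_experts * ffn_hidden_size, d_model)

expert_size = ffn_hidden_size
block_size = expert_size // world_size        # rows of each expert owned by this rank
start, stop = rank * block_size, (rank + 1) * block_size

shards = []
for i in range(moe_num_experts):
    offset = i * expert_size                  # where expert i starts in the fused tensor
    shards.append(fused[start + offset : stop + offset])

# One [block_size, d_model] shard per expert, mirroring the loop in the hunk above.
assert all(s.shape == (block_size, d_model) for s in shards)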
@@ -391,9 +502,7 @@ class DbrxNormAttentionNorm(nn.Module):
         )
 
         # faster post attention rms norm
-        normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
-        )
+        normed_attn_res_output, attn_res = self.norm_2(attn_output, res)
 
         return normed_attn_res_output, attn_res
 
@@ -663,6 +772,7 @@ class BlockSparseMoE(nn.Module):
         weights = weights / torch.norm(
             weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
         )
+        weights = weights.to(x.dtype)
 
         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -703,8 +813,6 @@ class DenseMoE(nn.Module):
     def __init__(self, prefix, config: DbrxConfig, weights):
         super().__init__()
 
-        raise NotImplementedError("Quantization is not implemented for Dbrx")
-
         self.moe_normalize_expert_weights = (
             config.ffn_config.moe_normalize_expert_weights
         )
@@ -731,24 +839,24 @@ class DenseMoE(nn.Module):
             config, f"{prefix}.router.layer", weights, bias=False
         )
 
-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
+        self.w1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )
+        self.w2 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w2",
+            weights=weights,
+            cls=TensorParallelRowLinear,
+        )
+        self.v1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.v1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )
 
         self.process_group = weights.process_group
 
@@ -764,26 +872,30 @@ class DenseMoE(nn.Module):
         # gate_logits: (sequence_length, n_experts)
         gate_logits = self.gate(x)
         # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
 
         if self.top_k < self.num_experts:
             _, not_selected_experts = torch.topk(
-                all_probs,
+                weights,
                 self.num_experts - self.top_k,
                 largest=False,
                 sorted=False,
                 dim=1,
             )
             # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
+            weights.scatter_(1, not_selected_experts, 0)
 
         # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        if self.moe_normalize_expert_weights:
+            weights = weights / torch.norm(
+                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+            )
+        weights = weights.to(x.dtype)
 
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
         for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.act(self.w1[i](x)) * self.v1[i](x)
             h = self.w2[i](h, reduce=False)
             # Add expert output to out with masking
             out += h * weights[:, i].view(-1, 1)
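Note (not part of the diff): the routing in the rewritten `DenseMoE.forward` is a softmax over the gate logits, zeroing of the experts outside the top-k, then an optional p-norm renormalization. A self-contained sketch on made-up numbers; `p = 1.0` is only an example value for `moe_normalize_expert_weights`:

# Standalone sketch of the top-k routing math used above.
import torch

num_experts, top_k, p = 4, 2, 1.0
gate_logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])   # one token, assumed values

weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# Zero the (num_experts - top_k) smallest weights instead of gathering the top_k.
_, not_selected = torch.topk(weights, num_experts - top_k, largest=False, dim=1)
weights.scatter_(1, not_selected, 0)
# Renormalize the surviving weights; with p=1 they sum to 1 again.
weights = weights / torch.norm(weights, p=p, dim=-1, keepdim=True)
print(weights)   # only the two largest experts keep a non-zero weight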
@@ -821,7 +933,7 @@ class DbrxLayer(nn.Module):
         max_s,
     ):
         # Self Attention
-        attn_output, attn_res = self.self_attn(
+        attn_output, attn_res = self.attn(
             hidden_states,
             residual,
             cos,
@@ -861,9 +973,9 @@ class DbrxModel(torch.nn.Module):
             prefix="transformer.norm_f", weights=weights, eps=1e-5
         )
 
-        self.head_size = self.layers[0].self_attn.head_size
-        self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+        self.head_size = self.layers[0].attn.self_attn.head_size
+        self.num_heads = self.layers[0].attn.self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads
 
     def forward(
         self,
@@ -880,7 +992,7 @@ class DbrxModel(torch.nn.Module):
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
 
@@ -552,6 +552,7 @@ class BlockSparseMoE(nn.Module):
 
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)
 
         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -660,6 +661,7 @@ class DenseMoE(nn.Module):
 
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)
 
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
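Note (not part of the diff): the two hunks above add the same cast in another pair of MoE forward passes. The routing weights come out of a float32 softmax, so without `weights.to(x.dtype)` the per-expert mixing silently upcasts the block output. A tiny sketch of that dtype behaviour:

# Standalone sketch: a mixed-dtype multiply upcasts unless the weights are cast back.
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)                       # expert output dtype
weights = torch.nn.functional.softmax(torch.randn(4, 2), dim=1, dtype=torch.float)

print((x * weights[:, 0].view(-1, 1)).dtype)                      # torch.float32 (upcast)
print((x * weights.to(x.dtype)[:, 0].view(-1, 1)).dtype)          # torch.bfloat16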
@@ -3,6 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from typing import Optional
+from transformers import AutoTokenizer
 from transformers.models.gpt2 import GPT2TokenizerFast
 
 from text_generation_server.models import FlashCausalLM
@@ -36,6 +37,17 @@ class FlashDbrx(FlashCausalLM):
         else:
             raise NotImplementedError("FlashDBRX is only available on GPU")
 
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+                use_fast=True,
+                from_slow=False,
+            )
+        except:
             # FIXME: change back to model id once the tokenizer.json is merged
             tokenizer = GPT2TokenizerFast.from_pretrained(
                 "Xenova/dbrx-instruct-tokenizer",
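Note (not part of the diff): the hunk above tries the checkpoint's own tokenizer first and only falls back to the community-hosted `Xenova/dbrx-instruct-tokenizer` if that fails. A standalone sketch of the same control flow; the `model_id` value is an assumption, and loading either tokenizer needs network access or a local cache:

# Standalone sketch of the tokenizer fallback.
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast

model_id = "databricks/dbrx-instruct"   # assumed checkpoint id

try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
except Exception:
    # Same fallback as the model code above.
    tokenizer = GPT2TokenizerFast.from_pretrained("Xenova/dbrx-instruct-tokenizer")

print(type(tokenizer).__name__)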