feat(models): Add DBRX

OlivierDehaene 2024-03-29 18:41:35 +01:00
parent 2c83d09d3b
commit dcfefc425a
3 changed files with 188 additions and 62 deletions

server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py

@@ -35,6 +35,7 @@ from text_generation_server.utils.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.utils.log import log_once
 
 HAS_MEGABLOCKS = True
 try:
@@ -176,23 +177,96 @@ def _load_gqa(config, prefix: str, weights):
     assert config.d_model % config.n_heads == 0
     assert config.n_heads % weights.process_group.size() == 0
 
-    weight = weights.get_weights_col_packed_qkv(
-        prefix=f"{prefix}.Wqkv",
-        quantize=config.quantize,
-    )
+    head_dim = config.d_model // config.n_heads
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
 
-    if config.quantize not in ["gptq", "awq"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    q_block_size = config.d_model // world_size
+    q_start = rank * q_block_size
+    q_stop = (rank + 1) * q_block_size
 
-        head_size = config.d_model // config.n_heads
-        num_heads = config.n_heads // weights.process_group.size()
-        num_key_value_heads = (
-            config.attn_config.kv_n_heads // weights.process_group.size()
-        )
+    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
+    k_offset = config.d_model
+    k_start = k_offset + rank * kv_block_size
+    k_stop = k_offset + (rank + 1) * kv_block_size
+
+    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
+    v_start = v_offset + rank * kv_block_size
+    v_stop = v_offset + (rank + 1) * kv_block_size
+
+    if config.quantize in ["gptq", "awq"]:
+        try:
+            qweight_slice = weights._get_slice(f"{prefix}.qweight")
+            q_qweight = qweight_slice[:, q_start:q_stop]
+            k_qweight = qweight_slice[:, k_start:k_stop]
+            v_qweight = qweight_slice[:, v_start:v_stop]
+
+            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
+            )
+
+        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
+        q_qzeros = qzeros_slice[:, q_start:q_stop]
+        k_qzeros = qzeros_slice[:, k_start:k_stop]
+        v_qzeros = qzeros_slice[:, v_start:v_stop]
+
+        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
+
+        scales_slice = weights._get_slice(f"{prefix}.scales")
+        q_scales = scales_slice[:, q_start:q_stop]
+        k_scales = scales_slice[:, k_start:k_stop]
+        v_scales = scales_slice[:, v_start:v_stop]
+
+        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
+
+        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
+
+        from text_generation_server.utils.layers import HAS_EXLLAMA
+
+        use_exllama = (
+            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
+        )
 
-        assert list(weight.shape) == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.d_model,
-        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.attn_config.kv_n_heads) * head_size, config.d_model]}"
+        if config.quantize == "gptq" and quant_method == "gptq":
+            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
+            q_g_idx = g_idx_slice[:, q_start:q_stop]
+            k_g_idx = g_idx_slice[:, k_start:k_stop]
+            v_g_idx = g_idx_slice[:, v_start:v_stop]
+
+            w = [q_g_idx, k_g_idx, v_g_idx]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif config.quantize == "gptq" and quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.utils.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+    else:
+        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
+        q = qkv_slice[q_start:q_stop]
+        k = qkv_slice[k_start:k_stop]
+        v = qkv_slice[v_start:v_stop]
+
+        weight = torch.cat([q, k, v], dim=0)
+        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     return TensorParallelColumnLinear(
         get_linear(weight, bias=None, quantize=config.quantize)
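
Note on the layout assumed by the slicing above: DBRX stores the attention projections as one fused Wqkv tensor laid out as [Q: d_model rows][K: kv_n_heads * head_dim rows][V: kv_n_heads * head_dim rows], so each tensor-parallel rank cuts one contiguous block out of each of the three regions rather than slicing the fused tensor as a whole. A minimal sketch of the offset arithmetic, using DBRX's published dimensions (d_model=6144, n_heads=48, kv_n_heads=8) and a hypothetical shard count:

# Sketch of the fused-Wqkv shard arithmetic; world_size is a made-up example.
d_model = 6144           # hidden size
n_heads = 48             # query heads
kv_n_heads = 8           # shared K/V heads (GQA)
world_size = 8           # hypothetical number of tensor-parallel shards
head_dim = d_model // n_heads  # 128

for rank in range(world_size):
    # Each rank owns one contiguous block of the Q region...
    q_block = d_model // world_size
    q_start, q_stop = rank * q_block, (rank + 1) * q_block
    # ...plus matching blocks of the K and V regions that follow it.
    kv_block = (kv_n_heads * head_dim) // world_size
    k_off = d_model
    v_off = d_model + kv_n_heads * head_dim
    k_start, k_stop = k_off + rank * kv_block, k_off + (rank + 1) * kv_block
    v_start, v_stop = v_off + rank * kv_block, v_off + (rank + 1) * kv_block
    print(rank, (q_start, q_stop), (k_start, k_stop), (v_start, v_stop))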
@@ -200,9 +274,6 @@ def _load_gqa(config, prefix: str, weights):
 
 def _load_experts(config, prefix, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Dbrx does not support weight quantization yet.")
-
     world_size = weights.process_group.size()
     rank = weights.process_group.rank()
@@ -221,9 +292,9 @@ def _load_experts(config, prefix, weights):
         device=weights.device,
     )
 
-    slice_ = weights._get_slice(f"{prefix}.weight")
+    slice_ = weights._get_slice(f"{prefix}")
 
-    for i in range(config.num_local_experts):
+    for i in range(config.ffn_config.moe_num_experts):
         offset = i * expert_size
         expert_slice = slice_[start + offset : stop + offset]
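
The loop above depends on DBRX's fused expert layout: the moe_num_experts expert matrices are concatenated along dim 0, so rank r's shard of expert i sits at rows i * expert_size + start : i * expert_size + stop. A toy reconstruction of that indexing with plain tensors and made-up sizes (the real code goes through the Weights helper):

import torch

# Toy fused expert tensor: 4 experts of 6 rows each, hidden dim 3 (all made up).
moe_num_experts, expert_size, d_model = 4, 6, 3
fused = torch.arange(moe_num_experts * expert_size * d_model, dtype=torch.float32)
fused = fused.view(moe_num_experts * expert_size, d_model)

world_size, rank = 2, 1                  # pretend we are shard 1 of 2
block_size = expert_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

# Same indexing as the loop above: expert i's shard sits at a fixed offset.
shards = [
    fused[i * expert_size + start : i * expert_size + stop]
    for i in range(moe_num_experts)
]
assert all(s.shape == (block_size, d_model) for s in shards)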
@@ -233,6 +304,46 @@ def _load_experts(config, prefix, weights):
 
     return tensor
 
+
+def _load_experts_quantized(config, prefix, weights, cls):
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
+
+    assert (
+        config.ffn_config.ffn_hidden_size % world_size == 0
+    ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+    expert_size = config.ffn_config.ffn_hidden_size
+    block_size = expert_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size
+
+    slice_ = weights._get_slice(f"{prefix}")
+
+    experts = []
+    for i in range(config.ffn_config.moe_num_experts):
+        if config.quantize in ["gptq", "awq"]:
+            raise NotImplementedError(
+                "Dbrx does not support gptq/awq quantization yet."
+            )
+        else:
+            offset = i * expert_size
+            expert_slice = (
+                slice_[start + offset : stop + offset]
+                .to(dtype=weights.dtype)
+                .to(device=weights.device)
+            )
+
+            if cls == TensorParallelRowLinear:
+                expert_slice = expert_slice.t().contiguous()
+                linear = get_linear(expert_slice, None, config.quantize)
+                experts.append(cls(linear, weights.process_group))
+            else:
+                linear = get_linear(expert_slice, None, config.quantize)
+                experts.append(cls(linear))
+
+    return experts
+
+
 class DbrxAttention(torch.nn.Module):
     def __init__(
         self,
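
One detail worth calling out in the new _load_experts_quantized: get_linear builds a layer computing x @ W.T with W shaped [out_features, in_features]. Each DBRX expert shard comes off disk as [ffn_hidden_size, d_model], which already matches the column-parallel w1/v1 but is transposed relative to what the row-parallel w2 needs, hence the .t().contiguous() branch. A shape-only sketch with hypothetical sizes:

import torch

d_model, ffn_block = 8, 4                  # hypothetical per-rank sizes
slice_ = torch.randn(ffn_block, d_model)   # one expert shard as stored on disk

# Column-parallel (w1 / v1): out_features = ffn_block, in_features = d_model,
# so the stored shard can be used directly.
x = torch.randn(2, d_model)
h = torch.nn.functional.linear(x, slice_)      # -> (2, ffn_block)

# Row-parallel (w2): out_features = d_model, in_features = ffn_block,
# so the stored shard must be transposed first.
w2 = slice_.t().contiguous()                   # -> (d_model, ffn_block)
y = torch.nn.functional.linear(h, w2)          # -> (2, d_model)
assert y.shape == (2, d_model)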
@@ -391,9 +502,7 @@ class DbrxNormAttentionNorm(nn.Module):
         )
 
         # faster post attention rms norm
-        normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
-        )
+        normed_attn_res_output, attn_res = self.norm_2(attn_output, res)
 
         return normed_attn_res_output, attn_res
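
This fixes a reference to an attribute that does not exist: DbrxNormAttentionNorm names its second layer norm norm_2. Like the other fused norms in TGI, it is called with (hidden_states, residual) and returns both the normalized output and the updated residual, which is why one call replaces the old two-step block. A rough functional equivalent of that calling convention (plain PyTorch, not the fused kernel):

import torch

def fused_layer_norm(hidden, residual, weight, bias, eps=1e-5):
    # Residual add fused with the norm: returns (normed, new_residual),
    # mirroring how norm_2 is called above.
    res = hidden + residual
    normed = torch.nn.functional.layer_norm(res, res.shape[-1:], weight, bias, eps)
    return normed, res

h, r = torch.randn(2, 16), torch.randn(2, 16)
normed_attn_res_output, attn_res = fused_layer_norm(
    h, r, torch.ones(16), torch.zeros(16)
)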
@@ -663,6 +772,7 @@ class BlockSparseMoE(nn.Module):
             weights = weights / torch.norm(
                 weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
             )
+        weights = weights.to(x.dtype)
 
         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
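
The added cast matters for mixed precision: the routing softmax is deliberately computed in float32 for numerical stability, and without casting back to the activation dtype the weight-times-expert-output products below would silently promote the whole MoE output to float32. A small illustration:

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)           # activations in bf16
gate_logits = torch.randn(4, 3, dtype=torch.bfloat16)

# The softmax is upcast to float32 for numerical stability...
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# ...so without the cast, multiplying by bf16 expert outputs would silently
# promote the result to float32.
weights = weights.to(x.dtype)
assert (weights[:, 0:1] * x).dtype == torch.bfloat16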
@@ -703,8 +813,6 @@ class DenseMoE(nn.Module):
     def __init__(self, prefix, config: DbrxConfig, weights):
         super().__init__()
 
-        raise NotImplementedError("Quantization is not implemented for Dbrx")
-
         self.moe_normalize_expert_weights = (
             config.ffn_config.moe_normalize_expert_weights
         )
@@ -731,24 +839,24 @@ class DenseMoE(nn.Module):
             config, f"{prefix}.router.layer", weights, bias=False
         )
 
-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
+        self.w1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )
+        self.w2 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w2",
+            weights=weights,
+            cls=TensorParallelRowLinear,
+        )
+        self.v1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.v1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )
 
         self.process_group = weights.process_group
@@ -764,26 +872,30 @@ class DenseMoE(nn.Module):
         # gate_logits: (sequence_length, n_experts)
         gate_logits = self.gate(x)
 
         # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
 
         if self.top_k < self.num_experts:
             _, not_selected_experts = torch.topk(
-                all_probs,
+                weights,
                 self.num_experts - self.top_k,
                 largest=False,
                 sorted=False,
                 dim=1,
             )
             # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
+            weights.scatter_(1, not_selected_experts, 0)
 
         # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        if self.moe_normalize_expert_weights:
+            weights = weights / torch.norm(
+                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+            )
+        weights = weights.to(x.dtype)
 
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
         for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.act(self.w1[i](x)) * self.v1[i](x)
             h = self.w2[i](h, reduce=False)
             # Add expert output to out with masking
             out += h * weights[:, i].view(-1, 1)
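
Two behavioural changes in this hunk: the gate now masks the routing weights in place instead of keeping a separate all_probs tensor, and re-normalization switches from an unconditional sum-to-one division to DBRX's optional p-norm controlled by moe_normalize_expert_weights (p=1 reproduces the old behaviour for non-negative weights). A standalone sketch of the routing math with toy sizes (DBRX-Instruct itself routes 4 of 16 experts):

import torch

num_experts, top_k, p = 4, 2, 1.0   # toy sizes; p is the normalization order
gate_logits = torch.randn(3, num_experts)

weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

# Zero out everything except the top_k experts, in place.
_, not_selected = torch.topk(weights, num_experts - top_k, largest=False, dim=1)
weights.scatter_(1, not_selected, 0)

# DBRX-style re-normalization: divide by the p-norm of the surviving weights;
# p=1 reproduces the old sum-to-one division for non-negative weights.
weights = weights / torch.norm(weights, p=p, dim=-1, keepdim=True)
print(weights.sum(dim=1))           # ~1.0 per row when p=1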
@@ -821,7 +933,7 @@ class DbrxLayer(nn.Module):
         max_s,
     ):
         # Self Attention
-        attn_output, attn_res = self.self_attn(
+        attn_output, attn_res = self.attn(
             hidden_states,
             residual,
             cos,
@@ -861,9 +973,9 @@ class DbrxModel(torch.nn.Module):
             prefix="transformer.norm_f", weights=weights, eps=1e-5
         )
 
-        self.head_size = self.layers[0].self_attn.head_size
-        self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+        self.head_size = self.layers[0].attn.self_attn.head_size
+        self.num_heads = self.layers[0].attn.self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads
 
     def forward(
         self,
@@ -880,7 +992,7 @@ class DbrxModel(torch.nn.Module):
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
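
The renames in this and the two previous hunks all follow the module nesting in this file: DbrxLayer holds a DbrxNormAttentionNorm under .attn, and that wrapper holds the actual DbrxAttention under .self_attn, so reaching rotary_emb or head_size from the model requires layers[0].attn.self_attn. A stub sketch of the hierarchy (attribute values are DBRX's dimensions, given for illustration):

import torch

class DbrxAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # rotary_emb, head_size, num_heads, num_key_value_heads live here
        self.head_size, self.num_heads, self.num_key_value_heads = 128, 48, 8

class DbrxNormAttentionNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = DbrxAttention()    # norm_1 -> self_attn -> norm_2

class DbrxLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = DbrxNormAttentionNorm()

layer = DbrxLayer()
assert layer.attn.self_attn.head_size == 128  # hence layers[0].attn.self_attn...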

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -552,6 +552,7 @@ class BlockSparseMoE(nn.Module):
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)
 
         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -660,6 +661,7 @@ class DenseMoE(nn.Module):
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)
 
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)

server/text_generation_server/models/flash_dbrx.py

@@ -3,6 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from typing import Optional
+from transformers import AutoTokenizer
 from transformers.models.gpt2 import GPT2TokenizerFast
 
 from text_generation_server.models import FlashCausalLM
@@ -36,16 +37,27 @@ class FlashDbrx(FlashCausalLM):
         else:
             raise NotImplementedError("FlashDBRX is only available on GPU")
 
-        # FIXME: change back to model id once the tokenizer.json is merged
-        tokenizer = GPT2TokenizerFast.from_pretrained(
-            "Xenova/dbrx-instruct-tokenizer",
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-            use_fast=True,
-            from_slow=False,
-        )
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+                use_fast=True,
+                from_slow=False,
+            )
+        except:
+            # FIXME: change back to model id once the tokenizer.json is merged
+            tokenizer = GPT2TokenizerFast.from_pretrained(
+                "Xenova/dbrx-instruct-tokenizer",
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+                use_fast=True,
+                from_slow=False,
+            )
 
         config = DbrxConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
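
The loader now prefers the model's own tokenizer and only falls back to the Xenova mirror when the upstream repo cannot provide one. A hedged sanity check one could run to confirm the two are equivalent (requires network access; the databricks/dbrx-instruct id is assumed here and that repo is gated):

from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast

# Both loads hit the Hugging Face Hub; dbrx-instruct may require accepting a license.
ref = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", trust_remote_code=True)
fallback = GPT2TokenizerFast.from_pretrained("Xenova/dbrx-instruct-tokenizer")

sample = "DBRX is a mixture-of-experts model."
# The fallback should produce identical token ids to the official tokenizer.
assert ref.encode(sample) == fallback.encode(sample)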