feat(models): Add DBRX

Author: OlivierDehaene
Date:   2024-03-29 18:41:35 +01:00
Parent: 2c83d09d3b
Commit: dcfefc425a

3 changed files with 188 additions and 62 deletions

File 1 of 3 (DBRX modeling code):

@@ -35,6 +35,7 @@ from text_generation_server.utils.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.utils.log import log_once

 HAS_MEGABLOCKS = True
 try:
@@ -176,33 +177,103 @@ def _load_gqa(config, prefix: str, weights):
     assert config.d_model % config.n_heads == 0
     assert config.n_heads % weights.process_group.size() == 0

-    weight = weights.get_weights_col_packed_qkv(
-        prefix=f"{prefix}.Wqkv",
-        quantize=config.quantize,
-    )
+    head_dim = config.d_model // config.n_heads
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()

-    if config.quantize not in ["gptq", "awq"]:
+    q_block_size = config.d_model // world_size
+    q_start = rank * q_block_size
+    q_stop = (rank + 1) * q_block_size
+
+    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
+
+    k_offset = config.d_model
+    k_start = k_offset + rank * kv_block_size
+    k_stop = k_offset + (rank + 1) * kv_block_size
+
+    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
+    v_start = v_offset + rank * kv_block_size
+    v_stop = v_offset + (rank + 1) * kv_block_size
+
+    if config.quantize in ["gptq", "awq"]:
+        try:
+            qweight_slice = weights._get_slice(f"{prefix}.qweight")
+            q_qweight = qweight_slice[:, q_start:q_stop]
+            k_qweight = qweight_slice[:, k_start:k_stop]
+            v_qweight = qweight_slice[:, v_start:v_stop]
+            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
+            )
+
+        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
+        q_qzeros = qzeros_slice[:, q_start:q_stop]
+        k_qzeros = qzeros_slice[:, k_start:k_stop]
+        v_qzeros = qzeros_slice[:, v_start:v_stop]
+        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
+
+        scales_slice = weights._get_slice(f"{prefix}.scales")
+        q_scales = scales_slice[:, q_start:q_stop]
+        k_scales = scales_slice[:, k_start:k_stop]
+        v_scales = scales_slice[:, v_start:v_stop]
+        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
+
+        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
+
+        from text_generation_server.utils.layers import HAS_EXLLAMA
+
+        use_exllama = (
+            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
+        )
+
+        if config.quantize == "gptq" and quant_method == "gptq":
+            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
+            q_g_idx = g_idx_slice[:, q_start:q_stop]
+            k_g_idx = g_idx_slice[:, k_start:k_stop]
+            v_g_idx = g_idx_slice[:, v_start:v_stop]
+
+            w = [q_g_idx, k_g_idx, v_g_idx]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif config.quantize == "gptq" and quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.utils.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+    else:
+        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
+        q = qkv_slice[q_start:q_stop]
+        k = qkv_slice[k_start:k_stop]
+        v = qkv_slice[v_start:v_stop]
+        weight = torch.cat([q, k, v], dim=0)
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)

         head_size = config.d_model // config.n_heads
         num_heads = config.n_heads // weights.process_group.size()
         num_key_value_heads = (
             config.attn_config.kv_n_heads // weights.process_group.size()
         )
         assert list(weight.shape) == [
             (num_heads + 2 * num_key_value_heads) * head_size,
             config.d_model,
         ], f"{list(weight.shape)} != {[(num_heads + 2 * config.attn_config.kv_n_heads) * head_size, config.d_model]}"

     return TensorParallelColumnLinear(
         get_linear(weight, bias=None, quantize=config.quantize)
     )


 def _load_experts(config, prefix, weights):
+    if config.quantize is not None:
+        raise NotImplementedError("Dbrx does not support weight quantization yet.")
+
     world_size = weights.process_group.size()
     rank = weights.process_group.rank()
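For readers following the `_load_gqa` rewrite above: DBRX stores Q, K and V fused in a single `Wqkv` tensor, and the new code shards it by taking one contiguous block per rank from each of the three regions instead of going through the generic packed-QKV helper. A minimal, self-contained sketch of that offset arithmetic with toy sizes (illustrative only, not DBRX's real dimensions):

    # Toy illustration of the per-rank Q/K/V offsets computed in _load_gqa.
    d_model = 16          # config.d_model
    n_heads = 4           # config.n_heads
    kv_n_heads = 2        # config.attn_config.kv_n_heads
    world_size = 2        # number of tensor-parallel shards

    head_dim = d_model // n_heads

    for rank in range(world_size):
        q_block_size = d_model // world_size
        q_start, q_stop = rank * q_block_size, (rank + 1) * q_block_size

        kv_block_size = (kv_n_heads * head_dim) // world_size
        k_offset = d_model
        k_start, k_stop = k_offset + rank * kv_block_size, k_offset + (rank + 1) * kv_block_size

        v_offset = d_model + kv_n_heads * head_dim
        v_start, v_stop = v_offset + rank * kv_block_size, v_offset + (rank + 1) * kv_block_size

        print(rank, (q_start, q_stop), (k_start, k_stop), (v_start, v_stop))

Across ranks the Q slices cover the first d_model rows, the K slices the next kv_n_heads * head_dim rows, and the V slices the rest, so each rank ends up with (n_heads + 2 * kv_n_heads) * head_dim // world_size rows, exactly the shape asserted in the unquantized branch.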
@@ -221,9 +292,9 @@ def _load_experts(config, prefix, weights):
         device=weights.device,
     )

-    slice_ = weights._get_slice(f"{prefix}.weight")
+    slice_ = weights._get_slice(f"{prefix}")

-    for i in range(config.num_local_experts):
+    for i in range(config.ffn_config.moe_num_experts):
         offset = i * expert_size
         expert_slice = slice_[start + offset : stop + offset]
@@ -233,6 +304,46 @@ def _load_experts(config, prefix, weights):
     return tensor


+def _load_experts_quantized(config, prefix, weights, cls):
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
+
+    assert (
+        config.ffn_config.ffn_hidden_size % world_size == 0
+    ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+    expert_size = config.ffn_config.ffn_hidden_size
+    block_size = expert_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size
+
+    slice_ = weights._get_slice(f"{prefix}")
+
+    experts = []
+    for i in range(config.ffn_config.moe_num_experts):
+        if config.quantize in ["gptq", "awq"]:
+            raise NotImplementedError(
+                "Dbrx does not support gptq/awq quantization yet."
+            )
+        else:
+            offset = i * expert_size
+            expert_slice = (
+                slice_[start + offset : stop + offset]
+                .to(dtype=weights.dtype)
+                .to(device=weights.device)
+            )
+
+        if cls == TensorParallelRowLinear:
+            expert_slice = expert_slice.t().contiguous()
+            linear = get_linear(expert_slice, None, config.quantize)
+            experts.append(cls(linear, weights.process_group))
+        else:
+            linear = get_linear(expert_slice, None, config.quantize)
+            experts.append(cls(linear))
+
+    return experts
+
+
 class DbrxAttention(torch.nn.Module):
     def __init__(
         self,
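The new `_load_experts_quantized` helper reads a fused tensor in which every expert is stacked along the first dimension and cuts out, for each expert, only the rows owned by the current shard (gptq/awq still raise, other quantization modes fall through to `get_linear`). A toy sketch of that indexing, assuming the `[moe_num_experts * ffn_hidden_size, d_model]` layout implied by the code above (illustrative sizes only):

    import torch

    # Stand-ins for the DBRX config values used by _load_experts_quantized.
    moe_num_experts = 4       # config.ffn_config.moe_num_experts
    ffn_hidden_size = 8       # config.ffn_config.ffn_hidden_size
    d_model = 6               # config.d_model
    world_size, rank = 2, 1   # tensor-parallel topology

    # Fused tensor: every expert's weight stacked along dim 0.
    fused = torch.arange(moe_num_experts * ffn_hidden_size * d_model, dtype=torch.float32)
    fused = fused.view(moe_num_experts * ffn_hidden_size, d_model)

    expert_size = ffn_hidden_size
    block_size = expert_size // world_size
    start, stop = rank * block_size, (rank + 1) * block_size

    shards = []
    for i in range(moe_num_experts):
        offset = i * expert_size
        shards.append(fused[start + offset : stop + offset])  # this rank's rows of expert i

    print([tuple(s.shape) for s in shards])  # moe_num_experts slices of (block_size, d_model)

For the row-parallel `w2` the per-expert slice is additionally transposed before `get_linear`, since the down projection consumes the sharded hidden dimension instead of producing it.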
@@ -391,9 +502,7 @@ class DbrxNormAttentionNorm(nn.Module):
         )

         # faster post attention rms norm
-        normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
-        )
+        normed_attn_res_output, attn_res = self.norm_2(attn_output, res)

         return normed_attn_res_output, attn_res
@@ -663,6 +772,7 @@ class BlockSparseMoE(nn.Module):
         weights = weights / torch.norm(
             weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
         )
+        weights = weights.to(x.dtype)

         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -703,8 +813,6 @@ class DenseMoE(nn.Module):
     def __init__(self, prefix, config: DbrxConfig, weights):
         super().__init__()

-        raise NotImplementedError("Quantization is not implemented for Dbrx")
-
         self.moe_normalize_expert_weights = (
             config.ffn_config.moe_normalize_expert_weights
         )
@@ -731,24 +839,24 @@ class DenseMoE(nn.Module):
             config, f"{prefix}.router.layer", weights, bias=False
         )

-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
+        self.w1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )
+        self.w2 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w2",
+            weights=weights,
+            cls=TensorParallelRowLinear,
+        )
+        self.v1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.v1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )

         self.process_group = weights.process_group
@@ -764,26 +872,30 @@ class DenseMoE(nn.Module):
         # gate_logits: (sequence_length, n_experts)
         gate_logits = self.gate(x)

         # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

         if self.top_k < self.num_experts:
             _, not_selected_experts = torch.topk(
-                all_probs,
+                weights,
                 self.num_experts - self.top_k,
                 largest=False,
                 sorted=False,
                 dim=1,
             )
             # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
+            weights.scatter_(1, not_selected_experts, 0)

         # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        if self.moe_normalize_expert_weights:
+            weights = weights / torch.norm(
+                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+            )
+        weights = weights.to(x.dtype)

         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
         for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.act(self.w1[i](x)) * self.v1[i](x)
             h = self.w2[i](h, reduce=False)
             # Add expert output to out with masking
             out += h * weights[:, i].view(-1, 1)
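The reworked `DenseMoE.forward` above now mirrors the block-sparse path: softmax the router logits in float32, zero everything outside the top-k, optionally renormalize with the configured p-norm, cast back to the activation dtype, then accumulate `act(w1(x)) * v1(x)` through `w2` per expert. A compact sketch of just the routing math with dummy sizes (the expert layers themselves are omitted):

    import torch

    torch.manual_seed(0)
    seq_len, n_experts, top_k, p_norm = 5, 4, 2, 1.0

    gate_logits = torch.randn(seq_len, n_experts).half()

    # Upcast for a numerically stable softmax.
    weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

    # Zero the experts that are NOT in the top-k for each token.
    if top_k < n_experts:
        _, not_selected = torch.topk(weights, n_experts - top_k, largest=False, dim=1)
        weights.scatter_(1, not_selected, 0)

    # Optional renormalization with the configured p-norm
    # (moe_normalize_expert_weights in the DBRX config).
    if p_norm:
        weights = weights / torch.norm(weights, p=p_norm, dim=-1, keepdim=True)

    # Cast back so the multiply with half-precision expert outputs stays in fp16.
    weights = weights.to(gate_logits.dtype)
    print(weights.dtype, weights.sum(dim=1))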
@@ -821,7 +933,7 @@ class DbrxLayer(nn.Module):
         max_s,
     ):
         # Self Attention
-        attn_output, attn_res = self.self_attn(
+        attn_output, attn_res = self.attn(
             hidden_states,
             residual,
             cos,
@@ -861,9 +973,9 @@ class DbrxModel(torch.nn.Module):
             prefix="transformer.norm_f", weights=weights, eps=1e-5
         )

-        self.head_size = self.layers[0].self_attn.head_size
-        self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+        self.head_size = self.layers[0].attn.self_attn.head_size
+        self.num_heads = self.layers[0].attn.self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads

     def forward(
         self,
@@ -880,7 +992,7 @@ class DbrxModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )

File 2 of 3 (a second modeling file whose BlockSparseMoE and DenseMoE receive the same dtype fix):

@@ -552,6 +552,7 @@ class BlockSparseMoE(nn.Module):
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)

         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -660,6 +661,7 @@ class DenseMoE(nn.Module):
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)

         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
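The one-line additions in this second file apply the same fix as in the DBRX module: the router weights come out of a float32 softmax, so using them directly against half-precision expert outputs either promotes the result or trips up the in-place accumulation into an fp16 buffer. A small demonstration of the behaviour the added `weights.to(x.dtype)` cast avoids (a sketch, not code from the repository):

    import torch

    h = torch.ones(3, 4, dtype=torch.float16)   # expert output in fp16
    w = torch.nn.functional.softmax(torch.randn(3, 2), dim=1, dtype=torch.float)

    print((h * w[:, 0].view(-1, 1)).dtype)      # torch.float32: silently promoted

    out = h.new_zeros(3, 4)                     # fp16 accumulator, as in DenseMoE
    try:
        out += h * w[:, 0].view(-1, 1)          # in-place add cannot downcast float32
    except RuntimeError as err:
        print(err)

    w = w.to(h.dtype)                           # mirrors the cast added by this commit
    out += h * w[:, 0].view(-1, 1)              # now everything stays in fp16
    print(out.dtype)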

File 3 of 3 (the FlashDbrx model class):

@@ -3,6 +3,7 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
+from transformers import AutoTokenizer
 from transformers.models.gpt2 import GPT2TokenizerFast

 from text_generation_server.models import FlashCausalLM
@@ -36,6 +37,17 @@ class FlashDbrx(FlashCausalLM):
         else:
             raise NotImplementedError("FlashDBRX is only available on GPU")

+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+                use_fast=True,
+                from_slow=False,
+            )
+        except:
-        # FIXME: change back to model id once the tokenizer.json is merged
-        tokenizer = GPT2TokenizerFast.from_pretrained(
-            "Xenova/dbrx-instruct-tokenizer",
+            # FIXME: change back to model id once the tokenizer.json is merged
+            tokenizer = GPT2TokenizerFast.from_pretrained(
+                "Xenova/dbrx-instruct-tokenizer",