diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 34d5592e..dd0bcca5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -35,6 +35,7 @@ from text_generation_server.utils.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.utils.log import log_once
 
 HAS_MEGABLOCKS = True
 try:
@@ -176,23 +177,96 @@ def _load_gqa(config, prefix: str, weights):
     assert config.d_model % config.n_heads == 0
     assert config.n_heads % weights.process_group.size() == 0
 
-    weight = weights.get_weights_col_packed_qkv(
-        prefix=f"{prefix}.Wqkv",
-        quantize=config.quantize,
-    )
+    head_dim = config.d_model // config.n_heads
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
 
-    if config.quantize not in ["gptq", "awq"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    q_block_size = config.d_model // world_size
+    q_start = rank * q_block_size
+    q_stop = (rank + 1) * q_block_size
 
-    head_size = config.d_model // config.n_heads
-    num_heads = config.n_heads // weights.process_group.size()
-    num_key_value_heads = (
-        config.attn_config.kv_n_heads // weights.process_group.size()
+    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
+    k_offset = config.d_model
+    k_start = k_offset + rank * kv_block_size
+    k_stop = k_offset + (rank + 1) * kv_block_size
+
+    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
+    v_start = v_offset + rank * kv_block_size
+    v_stop = v_offset + (rank + 1) * kv_block_size
+
+    if config.quantize in ["gptq", "awq"]:
+        try:
+            qweight_slice = weights._get_slice(f"{prefix}.qweight")
+            q_qweight = qweight_slice[:, q_start:q_stop]
+            k_qweight = qweight_slice[:, k_start:k_stop]
+            v_qweight = qweight_slice[:, v_start:v_stop]
+
+            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
+            )
+
+        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
+        q_qzeros = qzeros_slice[:, q_start:q_stop]
+        k_qzeros = qzeros_slice[:, k_start:k_stop]
+        v_qzeros = qzeros_slice[:, v_start:v_stop]
+
+        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
+
+        scales_slice = weights._get_slice(f"{prefix}.scales")
+        q_scales = scales_slice[:, q_start:q_stop]
+        k_scales = scales_slice[:, k_start:k_stop]
+        v_scales = scales_slice[:, v_start:v_stop]
+
+        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
+
+        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
+
+        from text_generation_server.utils.layers import HAS_EXLLAMA
+
+        use_exllama = (
+            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
     )
-    assert list(weight.shape) == [
-        (num_heads + 2 * num_key_value_heads) * head_size,
-        config.d_model,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.attn_config.kv_n_heads) * head_size, config.d_model]}"
+
+        if config.quantize == "gptq" and quant_method == "gptq":
+            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
+            q_g_idx = g_idx_slice[:, q_start:q_stop]
+            k_g_idx = g_idx_slice[:, k_start:k_stop]
+            v_g_idx = g_idx_slice[:, v_start:v_stop]
+
+            w = [q_g_idx, k_g_idx, v_g_idx]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif config.quantize == "gptq" and quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.utils.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+    else:
+        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
+        q = qkv_slice[q_start:q_stop]
+        k = qkv_slice[k_start:k_stop]
+        v = qkv_slice[v_start:v_stop]
+
+        weight = torch.cat([q, k, v], dim=0)
+        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     return TensorParallelColumnLinear(
         get_linear(weight, bias=None, quantize=config.quantize)
@@ -200,9 +274,6 @@ def _load_gqa(config, prefix: str, weights):
 
 
 def _load_experts(config, prefix, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Dbrx does not support weight quantization yet.")
-
     world_size = weights.process_group.size()
     rank = weights.process_group.rank()
 
@@ -221,9 +292,9 @@ def _load_experts(config, prefix, weights):
         device=weights.device,
     )
 
-    slice_ = weights._get_slice(f"{prefix}.weight")
+    slice_ = weights._get_slice(f"{prefix}")
 
-    for i in range(config.num_local_experts):
+    for i in range(config.ffn_config.moe_num_experts):
         offset = i * expert_size
         expert_slice = slice_[start + offset : stop + offset]
 
@@ -233,6 +304,46 @@
     return tensor
 
 
+def _load_experts_quantized(config, prefix, weights, cls):
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
+
+    assert (
+        config.ffn_config.ffn_hidden_size % world_size == 0
+    ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+    expert_size = config.ffn_config.ffn_hidden_size
+    block_size = expert_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size
+
+    slice_ = weights._get_slice(f"{prefix}")
+
+    experts = []
+    for i in range(config.ffn_config.moe_num_experts):
+        if config.quantize in ["gptq", "awq"]:
+            raise NotImplementedError(
+                "Dbrx does not support gptq/awq quantization yet."
+            )
+        else:
+            offset = i * expert_size
+            expert_slice = (
+                slice_[start + offset : stop + offset]
+                .to(dtype=weights.dtype)
+                .to(device=weights.device)
+            )
+
+            if cls == TensorParallelRowLinear:
+                expert_slice = expert_slice.t().contiguous()
+                linear = get_linear(expert_slice, None, config.quantize)
+                experts.append(cls(linear, weights.process_group))
+            else:
+                linear = get_linear(expert_slice, None, config.quantize)
+                experts.append(cls(linear))
+
+    return experts
+
+
 class DbrxAttention(torch.nn.Module):
     def __init__(
         self,
@@ -391,9 +502,7 @@ class DbrxNormAttentionNorm(nn.Module):
         )
 
         # faster post attention rms norm
-        normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
-        )
+        normed_attn_res_output, attn_res = self.norm_2(attn_output, res)
 
         return normed_attn_res_output, attn_res
 
@@ -663,6 +772,7 @@ class BlockSparseMoE(nn.Module):
             weights = weights / torch.norm(
                 weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
             )
+        weights = weights.to(x.dtype)
 
         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -703,8 +813,6 @@ class DenseMoE(nn.Module):
     def __init__(self, prefix, config: DbrxConfig, weights):
         super().__init__()
 
-        raise NotImplementedError("Quantization is not implemented for Dbrx")
-
         self.moe_normalize_expert_weights = (
             config.ffn_config.moe_normalize_expert_weights
         )
@@ -731,24 +839,24 @@ class DenseMoE(nn.Module):
             config, f"{prefix}.router.layer", weights, bias=False
         )
 
-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
+        self.w1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )
+        self.w2 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.w2",
+            weights=weights,
+            cls=TensorParallelRowLinear,
+        )
+        self.v1 = _load_experts_quantized(
+            config,
+            prefix=f"{prefix}.experts.mlp.v1",
+            weights=weights,
+            cls=TensorParallelColumnLinear,
+        )
 
         self.process_group = weights.process_group
 
@@ -764,26 +872,30 @@ class DenseMoE(nn.Module):
         # gate_logits: (sequence_length, n_experts)
         gate_logits = self.gate(x)
         # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
 
         if self.top_k < self.num_experts:
             _, not_selected_experts = torch.topk(
-                all_probs,
+                weights,
                 self.num_experts - self.top_k,
                 largest=False,
                 sorted=False,
                 dim=1,
             )
             # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
+            weights.scatter_(1, not_selected_experts, 0)
 
         # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        if self.moe_normalize_expert_weights:
+            weights = weights / torch.norm(
+                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+            )
+        weights = weights.to(x.dtype)
 
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
         for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.act(self.w1[i](x)) * self.v1[i](x)
             h = self.w2[i](h, reduce=False)
             # Add expert output to out with masking
             out += h * weights[:, i].view(-1, 1)
@@ -821,7 +933,7 @@ class DbrxLayer(nn.Module):
         max_s,
     ):
         # Self Attention
-        attn_output, attn_res = self.self_attn(
+        attn_output, attn_res = self.attn(
             hidden_states,
             residual,
             cos,
@@ -861,9 +973,9 @@ class DbrxModel(torch.nn.Module):
             prefix="transformer.norm_f", weights=weights, eps=1e-5
         )
 
-        self.head_size = self.layers[0].self_attn.head_size
-        self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+        self.head_size = self.layers[0].attn.self_attn.head_size
+        self.num_heads = self.layers[0].attn.self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads
 
     def forward(
         self,
@@ -880,7 +992,7 @@ class DbrxModel(torch.nn.Module):
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 17d4f708..d71a3f0c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -552,6 +552,7 @@ class BlockSparseMoE(nn.Module):
 
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)
 
         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
@@ -660,6 +661,7 @@ class DenseMoE(nn.Module):
 
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)
 
         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py
index d394cbba..b5411e22 100644
--- a/server/text_generation_server/models/flash_dbrx.py
+++ b/server/text_generation_server/models/flash_dbrx.py
@@ -3,6 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from typing import Optional
+from transformers import AutoTokenizer
 from transformers.models.gpt2 import GPT2TokenizerFast
 
 from text_generation_server.models import FlashCausalLM
@@ -36,16 +37,27 @@ class FlashDbrx(FlashCausalLM):
         else:
             raise NotImplementedError("FlashDBRX is only available on GPU")
 
-        # FIXME: change back to model id once the tokenizer.json is merged
-        tokenizer = GPT2TokenizerFast.from_pretrained(
-            "Xenova/dbrx-instruct-tokenizer",
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-            use_fast=True,
-            from_slow=False,
-        )
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+                use_fast=True,
+                from_slow=False,
+            )
+        except:
+            # FIXME: change back to model id once the tokenizer.json is merged
+            tokenizer = GPT2TokenizerFast.from_pretrained(
+                "Xenova/dbrx-instruct-tokenizer",
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+                use_fast=True,
+                from_slow=False,
+            )
 
         config = DbrxConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code