From 14d83c176fe845ffbdb8a4ffc5e613df39054f4b Mon Sep 17 00:00:00 2001
From: Daniël de Kok
Date: Tue, 18 Jun 2024 10:55:43 +0200
Subject: [PATCH] Support exl2-quantized Qwen2 models

Fixes #2081.
---
 .../custom_modeling/flash_qwen2_modeling.py   | 27 +++----------------
 1 file changed, 4 insertions(+), 23 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index b1de58b2..df5a8ae9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -40,31 +40,12 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
-    weight = weights.get_multi_weights_col(
+    return TensorParallelColumnLinear.load_multi(
+        config,
         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
         dim=0,
-    )
-
-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-    head_size = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads // weights.process_group.size()
-    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
-        (num_heads + 2 * num_key_value_heads) * head_size,
-        config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
-
-    w = [
-        weights.get_sharded(f"{p}.bias", dim=0)
-        for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
-    ]
-    bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
-
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=bias, quantize=config.quantize)
+        weights=weights,
+        bias=True,
     )
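For reference, below is a minimal sketch of how `_load_gqa` reads once the patch is applied, assembled from the hunk's context and added lines. The import path for `TensorParallelColumnLinear` is an assumption, since the patch does not show the file's import block.

```python
# Sketch of _load_gqa after this patch, reconstructed from the hunk above.
# ASSUMPTION: the import location of TensorParallelColumnLinear; the patch
# does not touch or show the import section of flash_qwen2_modeling.py.
from text_generation_server.layers import TensorParallelColumnLinear


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    # load_multi concatenates the q/k/v projection weights along dim 0 and
    # leaves quantizer-specific handling to the shared weights loader, rather
    # than reimplementing dtype/device moves, shape checks, and bias
    # concatenation per model as the removed code did.
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
        weights=weights,
        bias=True,
    )
```

Delegating to `load_multi` lets exl2-quantized checkpoints go through the same loading path as the other quantizers, which is why the fix is a net deletion (4 insertions, 23 deletions).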