From cbced7f0f9ca0b62216223859b82a2632d1c7a1f Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 29 May 2024 12:42:11 -0400
Subject: [PATCH] feat: adjust attn weight loading logic (#1975)

This PR updates `load_attention` to prefer loading model-specific attention
weights based on `config.model_type`. Additionally, `TensorParallelColumnLinear.load_multi`
was previously called in two separate branches; this change reduces those
calls to a single path.
---
 .../custom_modeling/flash_llama_modeling.py   | 48 ++++++++-----------
 1 file changed, 21 insertions(+), 27 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 6e23aa2b..f722bf73 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -49,37 +49,31 @@ if SYSTEM == "rocm":
 
 def load_attention(config, prefix, weights):
     bias = config.attention_bias
-    if config.num_attention_heads != config.num_key_value_heads:
-        return TensorParallelColumnLinear.load_multi(
+
+    # if specific model type, load the correct attention
+    if config.model_type == "phi3":
+        return TensorParallelColumnLinear.load_qkv(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
+            prefix=f"{prefix}.qkv_proj",
             weights=weights,
             bias=bias,
         )
-    else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=bias,
-            )
-        elif config.model_type == "phi3":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.qkv_proj",
-                weights=weights,
-                bias=bias,
-            )
-        else:
-            return TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=bias,
-            )
+    elif config.model_type == "baichuan":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.W_pack",
+            weights=weights,
+            bias=bias,
+        )
+
+    # otherwise, load the default attention based on the number of heads
+    return TensorParallelColumnLinear.load_multi(
+        config,
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        dim=0,
+        weights=weights,
+        bias=bias,
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
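
A minimal, self-contained sketch (not part of the patch) of the dispatch order
the rewritten `load_attention` follows: fused QKV tensors for phi3 and
baichuan, and a single default path of separate q/k/v projections otherwise.
The `DummyConfig` class and `pick_qkv_prefixes` helper below are hypothetical
stand-ins for illustration only and are not part of text_generation_server:

def pick_qkv_prefixes(config, prefix):
    """Return the checkpoint tensor names the updated load_attention reads."""
    if config.model_type == "phi3":
        return [f"{prefix}.qkv_proj"]  # fused QKV, loaded via load_qkv
    if config.model_type == "baichuan":
        return [f"{prefix}.W_pack"]  # fused QKV, loaded via load_qkv
    # single default path: separate q/k/v projections, loaded via load_multi
    return [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]


class DummyConfig:
    def __init__(self, model_type):
        self.model_type = model_type


if __name__ == "__main__":
    for model_type in ("phi3", "baichuan", "llama"):
        prefixes = pick_qkv_prefixes(
            DummyConfig(model_type), "model.layers.0.self_attn"
        )
        print(model_type, prefixes)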