mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 21:12:07 +00:00
feat: adjust attn weight loading logic (#1975)
This PR updates `load_attention` to prefer loading specific attention based on the model type. Additionally there were two cases where `TensorParallelColumnLinear.load_multi` was called and this reduces it to a single path
This commit is contained in:
parent
2b204f0479
commit
4dca35fc62
@ -49,30 +49,24 @@ if SYSTEM == "rocm":
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
bias = config.attention_bias
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
if config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
elif config.model_type == "phi3":
|
||||
|
||||
# if specific model type, load the correct attention
|
||||
if config.model_type == "phi3":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
elif config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# otherwise, load the default attention based on the number of heads
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
|
Loading…
Reference in New Issue
Block a user