Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-24 08:22:07 +00:00
feat: adjust attn weight loading logic (#1975)
This PR updates `load_attention` to prefer loading model-specific attention weights based on the model type. Additionally, `TensorParallelColumnLinear.load_multi` was previously called in two places; this change reduces that to a single path.
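For reference, the resulting control flow of `load_attention` is sketched below. This is only a summary of the new shape of the function as it appears in the diff further down, not additional behavior; `TensorParallelColumnLinear` is the existing layer helper used by the modified module.

```python
def load_attention(config, prefix, weights):
    bias = config.attention_bias

    # Model types that ship a fused QKV tensor use the dedicated load_qkv path.
    if config.model_type == "phi3":
        return TensorParallelColumnLinear.load_qkv(
            config, prefix=f"{prefix}.qkv_proj", weights=weights, bias=bias
        )
    elif config.model_type == "baichuan":
        return TensorParallelColumnLinear.load_qkv(
            config, prefix=f"{prefix}.W_pack", weights=weights, bias=bias
        )

    # Everything else falls through to the single load_multi call, which
    # concatenates the separate q/k/v projections along dim 0.
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
        weights=weights,
        bias=bias,
    )
```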
This commit is contained in:
parent 612bc483b6
commit cbced7f0f9
@@ -49,37 +49,31 @@ if SYSTEM == "rocm":
 
 def load_attention(config, prefix, weights):
     bias = config.attention_bias
-    if config.num_attention_heads != config.num_key_value_heads:
-        return TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=bias,
-        )
-    else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=bias,
-            )
-        elif config.model_type == "phi3":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.qkv_proj",
-                weights=weights,
-                bias=bias,
-            )
-        else:
-            return TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=bias,
-            )
+
+    # if specific model type, load the correct attention
+    if config.model_type == "phi3":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.qkv_proj",
+            weights=weights,
+            bias=bias,
+        )
+    elif config.model_type == "baichuan":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.W_pack",
+            weights=weights,
+            bias=bias,
+        )
+
+    # otherwise, load the default attention based on the number of heads
+    return TensorParallelColumnLinear.load_multi(
+        config,
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        dim=0,
+        weights=weights,
+        bias=bias,
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):