feat: adjust attn weight loading logic

Dispatch on config.model_type instead of the num_attention_heads != num_key_value_heads check: phi3 (qkv_proj) and baichuan (W_pack) checkpoints always load their fused QKV weights via load_qkv, and every other model falls back to loading the separate q_proj/k_proj/v_proj tensors with load_multi.

drbh 2024-05-29 15:05:57 +00:00
parent 612bc483b6
commit 3cf4354944


@@ -49,37 +49,31 @@ if SYSTEM == "rocm":
 
 def load_attention(config, prefix, weights):
     bias = config.attention_bias
-    if config.num_attention_heads != config.num_key_value_heads:
-        return TensorParallelColumnLinear.load_multi(
+
+    # if specific model type, load the correct attention
+    if config.model_type == "phi3":
+        return TensorParallelColumnLinear.load_qkv(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
+            prefix=f"{prefix}.qkv_proj",
             weights=weights,
             bias=bias,
         )
-    else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=bias,
-            )
-        elif config.model_type == "phi3":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.qkv_proj",
-                weights=weights,
-                bias=bias,
-            )
-        else:
-            return TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=bias,
-            )
+    elif config.model_type == "baichuan":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.W_pack",
+            weights=weights,
+            bias=bias,
+        )
+
+    # otherwise, load the default attention based on the number of heads
+    return TensorParallelColumnLinear.load_multi(
+        config,
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        dim=0,
+        weights=weights,
+        bias=bias,
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
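
For readers skimming the diff, here is a minimal, self-contained sketch (not TGI code) of the dispatch order this commit introduces. DummyConfig and pick_attention_loader are hypothetical stand-ins used only for illustration; the real load_attention returns TensorParallelColumnLinear layers built from the sharded checkpoint weights.

# Sketch only: mirrors the new branching, with stub loaders instead of
# TensorParallelColumnLinear. model_type is checked first, and the old
# MHA-vs-GQA guard (num_attention_heads != num_key_value_heads) no longer
# decides which loader is used.
from dataclasses import dataclass


@dataclass
class DummyConfig:  # hypothetical stand-in for a transformers config
    model_type: str
    num_attention_heads: int = 32
    num_key_value_heads: int = 8


def pick_attention_loader(config, prefix):
    if config.model_type == "phi3":
        # fused QKV tensor stored under qkv_proj
        return ("load_qkv", f"{prefix}.qkv_proj")
    elif config.model_type == "baichuan":
        # packed QKV tensor stored under W_pack
        return ("load_qkv", f"{prefix}.W_pack")
    # default: separate q/k/v projections concatenated along dim 0
    return (
        "load_multi",
        [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
    )


if __name__ == "__main__":
    for model_type in ("phi3", "baichuan", "llama"):
        cfg = DummyConfig(model_type=model_type)
        print(model_type, "->", pick_attention_loader(cfg, "model.layers.0.self_attn"))

Running the sketch shows that phi3 and baichuan resolve to their fused-weight loaders regardless of the head configuration, while any other model type takes the load_multi path.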