Phi3 support.

Nicolas Patry 2024-04-23 08:50:18 +00:00
parent ed72e92126
commit 7d31cb6e75
2 changed files with 23 additions and 8 deletions


@@ -327,7 +327,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "llama" or model_type == "baichuan":
+    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
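
Note: get_model dispatches on the "model_type" field of the checkpoint's config.json. Phi-3 reports "phi3" but otherwise follows the Llama computation graph, so it can reuse the flash-attention Llama handler. A minimal sketch of the routing idea (hypothetical helper name, not TGI's actual code):

import json

def flash_model_class(config_path: str) -> str:
    # Read the model_type that transformers writes into config.json.
    with open(config_path) as f:
        model_type = json.load(f)["model_type"]
    # Phi-3 shares Llama's computation graph, so it reuses FlashLlama.
    if model_type in ("llama", "baichuan", "phi3"):
        return "FlashLlama"
    raise NotImplementedError(f"unsupported model_type: {model_type}")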


@@ -101,6 +101,13 @@ def load_attention(config, prefix, weights):
             weights=weights,
             bias=False,
         )
+    elif config.model_type == "phi3":
+        return TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.qkv_proj",
+            weights=weights,
+            bias=False,
+        )
     else:
         return TensorParallelColumnLinear.load_multi(
             config,
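
Note: the single-tensor load works here because Phi-3 checkpoints ship one fused qkv_proj.weight of shape [(num_heads + 2 * num_kv_heads) * head_dim, hidden_size], whereas Llama ships separate q_proj/k_proj/v_proj tensors that load_multi concatenates at load time. A sketch of how the fused layout relates to the split one (illustrative dimensions, not TGI code):

import torch

# Illustrative Phi-3-mini-like dimensions; real values come from the config.
hidden_size, head_dim = 3072, 96
num_heads, num_kv_heads = 32, 32

q = torch.randn(num_heads * head_dim, hidden_size)
k = torch.randn(num_kv_heads * head_dim, hidden_size)
v = torch.randn(num_kv_heads * head_dim, hidden_size)

# A Phi-3-style checkpoint stores the row-wise concatenation as one tensor ...
qkv_proj = torch.cat([q, k, v], dim=0)

# ... which splits back into the Llama-style layout without loss.
q2, k2, v2 = torch.split(
    qkv_proj,
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=0,
)
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
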
@@ -257,13 +264,21 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
-            weights=weights,
-            dim=0,
-            bias=False,
-        )
+        if config.model_type == "phi3":
+            self.gate_up_proj = TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.gate_up_proj",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+                weights=weights,
+                dim=0,
+                bias=False,
+            )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",