mirror of https://github.com/huggingface/text-generation-inference.git
Phi3 support.

commit 7d31cb6e75
parent ed72e92126
@@ -327,7 +327,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "llama" or model_type == "baichuan":
+    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
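The hunk above only touches the model dispatch in get_model: a checkpoint whose config.json declares model_type "phi3" is now routed through the same FlashLlama branch as "llama" and "baichuan", since Phi-3 is architecturally Llama-like. A minimal sketch of that dispatch decision (plain Python, not TGI code; the helper name and inline configs are illustrative):

    import json

    def picks_flash_llama(config_json: str) -> bool:
        # TGI reads "model_type" from the checkpoint's config.json; after this
        # commit "phi3" takes the FlashLlama branch alongside "llama" and
        # "baichuan" (assuming FLASH_ATTENTION is available on the machine).
        model_type = json.loads(config_json)["model_type"]
        return model_type in ("llama", "baichuan", "phi3")

    print(picks_flash_llama('{"model_type": "phi3"}'))  # True
    print(picks_flash_llama('{"model_type": "gpt2"}'))  # False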
@@ -101,6 +101,13 @@ def load_attention(config, prefix, weights):
                 weights=weights,
                 bias=False,
             )
+        elif config.model_type == "phi3":
+            return TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.qkv_proj",
+                weights=weights,
+                bias=False,
+            )
         else:
             return TensorParallelColumnLinear.load_multi(
                 config,
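This hunk accounts for Phi-3's checkpoint layout: the attention projections ship as a single fused qkv_proj tensor rather than separate q_proj/k_proj/v_proj, so load_attention loads that one tensor directly instead of concatenating three prefixes with load_multi. A rough sketch of the layout, with Phi-3-mini-like shapes assumed rather than read from a real checkpoint:

    import torch

    # Illustrative Phi-3-mini-4k sizes; a real loader reads them from config.json.
    hidden_size, num_heads, num_kv_heads = 3072, 32, 32
    head_dim = hidden_size // num_heads

    # Phi-3 stores one fused [q; k; v] matrix under "<prefix>.qkv_proj.weight".
    qkv_proj = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden_size)

    # Splitting along dim 0 recovers what a Llama-style checkpoint would store
    # as separate q_proj, k_proj and v_proj weights.
    q, k, v = torch.split(
        qkv_proj,
        [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
        dim=0,
    )
    print(q.shape, k.shape, v.shape)  # each (3072, 3072) with these sizes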
@@ -257,6 +264,14 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
+        if config.model_type == "phi3":
+            self.gate_up_proj = TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.gate_up_proj",
+                weights=weights,
+                bias=False,
+            )
+        else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
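Same idea for the MLP: Phi-3 fuses gate_proj and up_proj into one gate_up_proj tensor, so LlamaMLP loads it with a single load call instead of load_multi over two prefixes. The forward pass is unchanged; below is a sketch of the equivalent SwiGLU computation with a fused weight (shapes assumed, not TGI code):

    import torch
    import torch.nn.functional as F

    # Illustrative Phi-3-mini-4k sizes (hidden_size=3072, intermediate_size=8192).
    hidden_size, intermediate_size = 3072, 8192
    gate_up_proj = torch.randn(2 * intermediate_size, hidden_size)
    down_proj = torch.randn(hidden_size, intermediate_size)

    def fused_swiglu_mlp(x: torch.Tensor) -> torch.Tensor:
        gate_up = x @ gate_up_proj.T              # single fused projection
        gate, up = gate_up.chunk(2, dim=-1)       # first half gate, second half up
        return (F.silu(gate) * up) @ down_proj.T  # SwiGLU, then down projection

    print(fused_swiglu_mlp(torch.randn(2, hidden_size)).shape)  # torch.Size([2, 3072])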