From 7d31cb6e754f47a1053f530e9fd02812c92c01d0 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 23 Apr 2024 08:50:18 +0000
Subject: [PATCH] Phi3 support.

---
 .../text_generation_server/models/__init__.py |  2 +-
 .../custom_modeling/flash_llama_modeling.py   | 29 ++++++++++++++-----
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 06792b0d..e4e8717d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -327,7 +327,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "llama" or model_type == "baichuan":
+    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4cf0fcf2..953eb8a1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -101,6 +101,13 @@ def load_attention(config, prefix, weights):
             weights=weights,
             bias=False,
         )
+        elif config.model_type == "phi3":
+            return TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.qkv_proj",
+                weights=weights,
+                bias=False,
+            )
         else:
             return TensorParallelColumnLinear.load_multi(
                 config,
@@ -257,13 +264,21 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
-            weights=weights,
-            dim=0,
-            bias=False,
-        )
+        if config.model_type == "phi3":
+            self.gate_up_proj = TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.gate_up_proj",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+                weights=weights,
+                dim=0,
+                bias=False,
+            )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
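
Note for reviewers (not part of the patch): Phi-3 checkpoints ship the attention projections pre-fused into a single qkv_proj tensor, and the MLP gate/up projections pre-fused into a single gate_up_proj tensor. That is why the new branches call TensorParallelColumnLinear.load on one prefix instead of load_multi, which concatenates separate llama-style q_proj/k_proj/v_proj (or gate_proj/up_proj) weights along dim 0. The sketch below is a minimal illustration of that layout difference; the dict-based toy "checkpoints", tensor names, and shapes are illustrative assumptions, not taken from the repository.

    # Minimal sketch: why the phi3 branch can skip the load_multi concatenation.
    # Toy shapes and dict-based "checkpoints" are hypothetical.
    import torch

    hidden = 8  # toy hidden size

    # Llama-style checkpoint: three separate attention projection weights.
    llama_ckpt = {
        "self_attn.q_proj.weight": torch.randn(hidden, hidden),
        "self_attn.k_proj.weight": torch.randn(hidden, hidden),
        "self_attn.v_proj.weight": torch.randn(hidden, hidden),
    }

    # What load_multi effectively does for llama: concatenate along dim 0.
    fused_from_parts = torch.cat(
        [llama_ckpt[f"self_attn.{p}_proj.weight"] for p in "qkv"], dim=0
    )

    # Phi-3-style checkpoint: the same projections stored already fused.
    phi3_ckpt = {"self_attn.qkv_proj.weight": fused_from_parts.clone()}

    # What the new phi3 branch does: take the single fused tensor as-is.
    fused_direct = phi3_ckpt["self_attn.qkv_proj.weight"]

    # Both paths hand the same (3 * hidden, hidden) weight to the fused kernel.
    assert torch.equal(fused_from_parts, fused_direct)
    print(fused_direct.shape)  # torch.Size([24, 8])

The same reasoning applies to the LlamaMLP hunk: llama stores gate_proj and up_proj separately and the server fuses them at load time, while phi3 stores them already concatenated as gate_up_proj.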