diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index dbe49039..f72864a2 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -136,6 +136,11 @@ class ModelType(enum.Enum):
         "name": "Phi 3",
         "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
     }
+    PHI3SMALL = {
+        "type": "phi3small",
+        "name": "Phi 3 Small",
+        "url": "https://huggingface.co/microsoft/Phi-3-small-8k-instruct",
+    }
     GEMMA = {
         "type": "gemma",
         "name": "Gemma",
@@ -579,7 +584,12 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+    elif (
+        model_type == LLAMA
+        or model_type == BAICHUAN
+        or model_type == PHI3
+        or model_type == PHI3SMALL
+    ):
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index cef712f0..29f1b147 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -70,6 +70,13 @@ def load_attention(config, prefix, weights):
             weights=weights,
             bias=bias,
         )
+    elif config.model_type == "phi3small":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=bias,
+        )

     # otherwise, load the default attention based on the number of heads
     return TensorParallelColumnLinear.load_multi(
@@ -93,12 +100,20 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        if config.model_type == "phi3small":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_embedding_base,
+                device=weights.device,
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )

         self.softmax_scale = self.head_size**-0.5

@@ -114,12 +129,21 @@ class FlashLlamaAttention(torch.nn.Module):

         self.query_key_value = load_attention(config, prefix, weights)

-        self.o_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
-        )
+        if config.model_type == "phi3small":
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.dense",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.o_proj",
+                weights=weights,
+                bias=False,
+            )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -209,6 +233,13 @@ class LlamaMLP(nn.Module):
                 weights=weights,
                 bias=bias,
             )
+        elif config.model_type == "phi3small":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.up_proj",
+                weights=weights,
+                bias=bias,
+            )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
@@ -259,13 +290,16 @@ class FlashLlamaLayer(nn.Module):
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

+        if config.model_type == "phi3small":
+            eps = config.layer_norm_epsilon
+        else:
+            eps = config.rms_norm_eps
+
         self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
         )
         self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
+            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
         )

     def forward(
@@ -327,11 +361,19 @@ class FlashLlamaModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+
+        if config.model_type == "phi3small":
+            self.norm = FastRMSNorm.load(
+                prefix="model.final_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            self.norm = FastRMSNorm.load(
+                prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )

         self.gradient_checkpointing = False
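
Note (not part of the patch): the branches above amount to a small mapping from Phi-3-Small's checkpoint layout and config attribute names onto the existing Llama code path. The sketch below restates that mapping in one place; the helper function is hypothetical and only summarizes what the diff does.

def phi3small_overrides(config):
    # Attribute names and weight prefixes are taken from the branches added in the diff above.
    if config.model_type == "phi3small":
        return {
            "rope_base": config.rope_embedding_base,       # Llama path reads config.rope_theta
            "norm_eps": config.layer_norm_epsilon,         # Llama path reads config.rms_norm_eps
            "qkv_prefix": "query_key_value",               # fused QKV instead of separate q/k/v_proj
            "o_proj_prefix": "dense",                      # instead of "o_proj"
            "gate_up_prefix": "up_proj",                   # fused gate/up weight under "up_proj"
            "final_norm_prefix": "model.final_layernorm",  # instead of "model.norm"
        }
    return {}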