diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 7e7ddef1..679e1e2f 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -242,7 +242,13 @@ def get_model(
                 use_medusa=use_medusa,
             )
         else:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"))
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
 
     elif model_type == "phi-msft":
         if FLASH_ATTENTION:
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 9f33143f..d103973f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -25,34 +25,24 @@ class PhiConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=32,
-        hidden_act="gelu_fast", # llama uses silu
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        layer_norm_eps=1e-05, # rms in llama
-        use_cache=True,
+        hidden_act="gelu_fast",  # llama uses silu
+        layer_norm_eps=1e-05,  # rms in llama,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
-        pretraining_tp=1,
         tie_word_embeddings=False,
-        rope_scaling=None,
         rope_theta=10000.0,
-        resid_pdrop=0.1, # llama doesn't have this
-        partial_rotary_factor=0.5,
+        resid_pdrop=0.1,  # llama doesn't have this
+        partial_rotary_factor=0.5,  # important difference between llama and phi
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
         self.partial_rotary_factor = partial_rotary_factor
@@ -116,16 +106,16 @@ class FlashPhiAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
+        self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.num_heads,
+            dim=self.rotary_dim,
             base=config.rope_theta,
             device=weights.device,
         )
 
-        self.softmax_scale = self.head_size**-0.5
-        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
-
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
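
Note (not part of the patch): the attention change passes self.rotary_dim rather than self.num_heads to PositionRotaryEmbedding.static, because Phi applies rotary embeddings to only a fraction of each head (partial_rotary_factor). The snippet below is a minimal standalone sketch of that idea, assuming the rotate-half convention and a hypothetical helper apply_partial_rotary; the hidden_size=2048 / 32-head example values are illustrative, not taken from the diff.

    import torch

    def apply_partial_rotary(q, cos, sin, rotary_dim):
        # Rotate only the first `rotary_dim` dims of each head (rotate-half style);
        # the remaining dims pass through unchanged.
        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
        half = rotary_dim // 2
        q1, q2 = q_rot[..., :half], q_rot[..., half:]
        rotated = torch.cat((-q2, q1), dim=-1)
        return torch.cat((q_rot * cos + rotated * sin, q_pass), dim=-1)

    head_size = 2048 // 32             # 64, assuming hidden_size=2048 and 32 attention heads
    rotary_dim = int(0.5 * head_size)  # 32, mirrors self.rotary_dim with partial_rotary_factor=0.5
    q = torch.randn(4, 32, head_size)  # (tokens, heads, head_size)
    angles = torch.rand(4, 1, rotary_dim // 2)
    cos = torch.cos(angles).repeat(1, 1, 2)  # (tokens, 1, rotary_dim)
    sin = torch.sin(angles).repeat(1, 1, 2)
    out = apply_partial_rotary(q, cos, sin, rotary_dim)
    assert out.shape == q.shape
    assert torch.allclose(out[..., rotary_dim:], q[..., rotary_dim:])  # non-rotary half untouched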