From c49332adb6f02fde8dde9bda98f674076ccaa1ac Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 23 Jan 2024 00:18:29 +0000
Subject: [PATCH] fix: remove unused imports and duplicate spaces

---
 .../models/custom_modeling/flash_phi_modeling.py | 16 +---------------
 1 file changed, 1 insertion(+), 15 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index b49d0985..eb98a756 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -14,9 +14,7 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     TensorParallelHead,
     get_linear,
-    FastRMSNorm,
     FastLayerNorm,
-    FastLinear,
 )
 
 class PhiConfig(PretrainedConfig):
@@ -65,10 +63,8 @@ class PhiConfig(PretrainedConfig):
             **kwargs,
         )
 
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
-        # should never get here
         return _load_gqa(config, prefix, weights)
     else:
         if config.model_type == "baichuan":
@@ -79,7 +75,6 @@ def load_attention(config, prefix, weights):
                 bias=True,
             )
         else:
-            # should be here
             return TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -88,7 +83,6 @@ def load_attention(config, prefix, weights):
                 bias=True,
             )
 
-
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -114,7 +108,6 @@ def _load_gqa(config, prefix: str, weights):
         get_linear(weight, bias=True, quantize=config.quantize)
     )
 
-
 class FlashPhiAttention(torch.nn.Module):
     def __init__(
         self,
@@ -142,7 +135,6 @@ class FlashPhiAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
            )
 
-        # should be correct
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
@@ -162,8 +154,6 @@ class FlashPhiAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)
 
         self.rotary_emb_dim = 32
-
-
     def forward(
         self,
         hidden_states,
@@ -317,7 +307,6 @@ class FlashPhiLayer(nn.Module):
         return hidden_states, res
 
 
-
 class FlashPhiModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -387,7 +376,6 @@ class FlashPhiModel(torch.nn.Module):
         normed_hidden_states, _ = self.ln(hidden_states, residual)
         return normed_hidden_states
 
-
 class FlashPhiForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -424,6 +412,4 @@ class FlashPhiForCausalLM(torch.nn.Module):
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
 
-        logits = self.lm_head(hidden_states)
-
-        return logits
+        return self.lm_head(hidden_states)