fix: remove unused imports and duplicate spaces

Author: drbh
Date:   2024-01-23 00:18:29 +00:00
Parent: 2b43c5b0dd
Commit: c49332adb6


@@ -14,9 +14,7 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     TensorParallelHead,
     get_linear,
-    FastRMSNorm,
     FastLayerNorm,
-    FastLinear,
 )
 
 class PhiConfig(PretrainedConfig):
@@ -65,10 +63,8 @@ class PhiConfig(PretrainedConfig):
             **kwargs,
         )
 
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
-        # should never get here
         return _load_gqa(config, prefix, weights)
     else:
         if config.model_type == "baichuan":
@@ -79,7 +75,6 @@ def load_attention(config, prefix, weights):
                 bias=True,
             )
         else:
-            # should be here
             return TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -88,7 +83,6 @@ def load_attention(config, prefix, weights):
                 bias=True,
             )
 
-
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -114,7 +108,6 @@ def _load_gqa(config, prefix: str, weights):
         get_linear(weight, bias=True, quantize=config.quantize)
     )
 
-
 class FlashPhiAttention(torch.nn.Module):
     def __init__(
         self,
@@ -142,7 +135,6 @@ class FlashPhiAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
 
-        # should be correct
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
@@ -162,8 +154,6 @@ class FlashPhiAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)
         self.rotary_emb_dim = 32
 
-
-
     def forward(
         self,
         hidden_states,
@@ -317,7 +307,6 @@ class FlashPhiLayer(nn.Module):
         return hidden_states, res
 
 
-
 class FlashPhiModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -387,7 +376,6 @@ class FlashPhiModel(torch.nn.Module):
         normed_hidden_states, _ = self.ln(hidden_states, residual)
         return normed_hidden_states
 
-
 class FlashPhiForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -424,6 +412,4 @@ class FlashPhiForCausalLM(torch.nn.Module):
 
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-
-        logits = self.lm_head(hidden_states)
-        return logits
+        return self.lm_head(hidden_states)