Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: remove unused imports and duplicate spaces
parent 2b43c5b0dd
commit c49332adb6
@@ -14,9 +14,7 @@ from text_generation_server.utils.layers import (
    PositionRotaryEmbedding,
    TensorParallelHead,
    get_linear,
    FastRMSNorm,
    FastLayerNorm,
    FastLinear,
)

class PhiConfig(PretrainedConfig):
@@ -65,10 +63,8 @@ class PhiConfig(PretrainedConfig):
            **kwargs,
        )


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        # should never get here
        return _load_gqa(config, prefix, weights)
    else:
        if config.model_type == "baichuan":
@@ -79,7 +75,6 @@ def load_attention(config, prefix, weights):
                bias=True,
            )
        else:
            # should be here
            return TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -88,7 +83,6 @@ def load_attention(config, prefix, weights):
                bias=True,
            )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0
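
For context on the load_attention hunk above: loading the q_proj/k_proj/v_proj prefixes into one column-parallel linear boils down to concatenating the three weight matrices along the output dimension so a single matmul produces the packed QKV states. A rough standalone sketch in plain PyTorch (hypothetical shapes, not TGI's TensorParallelColumnLinear):

import torch

hidden_size = 2560  # illustrative Phi-2-like width, not read from any config
q_w = torch.randn(hidden_size, hidden_size)  # stand-ins for the q_proj/k_proj/v_proj weights
k_w = torch.randn(hidden_size, hidden_size)
v_w = torch.randn(hidden_size, hidden_size)

qkv_w = torch.cat([q_w, k_w, v_w], dim=0)  # pack along the output features: [3 * hidden_size, hidden_size]

x = torch.randn(4, hidden_size)  # 4 token embeddings
q, k, v = (x @ qkv_w.T).split(hidden_size, dim=-1)  # one matmul, then split back into Q, K, V
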
@@ -114,7 +108,6 @@ def _load_gqa(config, prefix: str, weights):
        get_linear(weight, bias=True, quantize=config.quantize)
    )


class FlashPhiAttention(torch.nn.Module):
    def __init__(
        self,
@@ -142,7 +135,6 @@ class FlashPhiAttention(torch.nn.Module):
                f"and `num_shards`: {weights.process_group.size()}"
            )

        # should be correct
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
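
The sharding arithmetic in the hunk above, spelled out with made-up numbers (the real values come from the Phi config and weights.process_group; the ValueError branch is the guard shown a few lines earlier):

num_attention_heads = 32   # hypothetical; Phi-2 uses 32 query heads
num_key_value_heads = 32   # hypothetical; here kv heads == query heads
num_shards = 4             # e.g. weights.process_group.size()

if num_attention_heads % num_shards != 0:
    # mirrors the `num_heads` / `num_shards` check raised above
    raise ValueError("`num_heads` must be divisible by `num_shards`")

num_heads = num_attention_heads // num_shards              # 8 query heads per shard
num_key_value_heads = num_key_value_heads // num_shards    # 8 kv heads per shard
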
@@ -162,8 +154,6 @@ class FlashPhiAttention(torch.nn.Module):
        ).repeat_interleave(self.num_groups)
        self.rotary_emb_dim = 32



    def forward(
        self,
        hidden_states,
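
The rotary_emb_dim = 32 line reflects Phi's partial rotary embeddings: only the first rotary_emb_dim dimensions of each head are rotated and the remaining dimensions pass through unchanged. A minimal sketch of that idea, assuming the half-split pairing convention (actual kernels may pair dimensions differently):

import torch

def apply_partial_rotary(x, cos, sin, rotary_dim):
    # x: [num_tokens, num_heads, head_dim]; cos/sin: [num_tokens, 1, rotary_dim // 2]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # Everything past rotary_dim is returned untouched.
    return torch.cat((rotated, x_pass), dim=-1)
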
@@ -317,7 +307,6 @@ class FlashPhiLayer(nn.Module):

        return hidden_states, res


class FlashPhiModel(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()
@@ -387,7 +376,6 @@ class FlashPhiModel(torch.nn.Module):
        normed_hidden_states, _ = self.ln(hidden_states, residual)
        return normed_hidden_states


class FlashPhiForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()
@@ -424,6 +412,4 @@ class FlashPhiForCausalLM(torch.nn.Module):
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        logits = self.lm_head(hidden_states)

        return logits
        return self.lm_head(hidden_states)
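
In the last hunk, judging by the line counts (-424,6 +412,4), the two-step logits = self.lm_head(hidden_states) / return logits appears to be collapsed into the single return self.lm_head(hidden_states). The lm_head_indices gather just above it is the usual trick of projecting only the positions whose logits are needed; a rough sketch with illustrative sizes:

import torch

vocab_size, hidden_size = 51200, 2560           # illustrative Phi-2-like sizes
lm_head = torch.nn.Linear(hidden_size, vocab_size)

hidden_states = torch.randn(10, hidden_size)    # 10 tokens from a batch of 2 sequences
lm_head_indices = torch.tensor([3, 9])          # e.g. the last token of each sequence (hypothetical)

hidden_states = hidden_states[lm_head_indices]  # keep only the rows we need logits for
logits = lm_head(hidden_states)                 # [2, vocab_size] instead of [10, vocab_size]
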