Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: remove unused imports and duplicate spaces
commit c49332adb6
parent 2b43c5b0dd
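For context on the first half of the fix, unused imports like the FastRMSNorm and FastLinear removed below are the kind of thing a linter flags automatically (flake8 reports them as F401). The following is a minimal, hypothetical sketch of such a check using only the standard-library ast module; it is a rough heuristic for illustration, not the tool used for this commit.

import ast
import sys


def unused_imports(source: str) -> list[str]:
    # Collect every name brought in by an import and every name actually
    # referenced in the module, then report the difference.
    tree = ast.parse(source)
    imported, used = set(), set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imported.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                imported.add(alias.asname or alias.name)
        elif isinstance(node, ast.Name):
            used.add(node.id)
    return sorted(imported - used)


if __name__ == "__main__":
    # Usage: python find_unused_imports.py some_module.py
    with open(sys.argv[1]) as f:
        print(unused_imports(f.read()))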
@@ -14,9 +14,7 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     TensorParallelHead,
     get_linear,
-    FastRMSNorm,
     FastLayerNorm,
-    FastLinear,
 )
 
 class PhiConfig(PretrainedConfig):
@@ -65,10 +63,8 @@ class PhiConfig(PretrainedConfig):
             **kwargs,
         )
 
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
-        # should never get here
         return _load_gqa(config, prefix, weights)
     else:
         if config.model_type == "baichuan":
@@ -79,7 +75,6 @@ def load_attention(config, prefix, weights):
                 bias=True,
             )
         else:
-            # should be here
             return TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -88,7 +83,6 @@ def load_attention(config, prefix, weights):
                 bias=True,
             )
 
-
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -114,7 +108,6 @@ def _load_gqa(config, prefix: str, weights):
         get_linear(weight, bias=True, quantize=config.quantize)
     )
 
-
 class FlashPhiAttention(torch.nn.Module):
     def __init__(
         self,
@@ -142,7 +135,6 @@ class FlashPhiAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
 
-        # should be correct
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
@@ -162,8 +154,6 @@ class FlashPhiAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)
         self.rotary_emb_dim = 32
 
-
-
     def forward(
         self,
         hidden_states,
@@ -317,7 +307,6 @@ class FlashPhiLayer(nn.Module):
 
         return hidden_states, res
 
-
 class FlashPhiModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -387,7 +376,6 @@ class FlashPhiModel(torch.nn.Module):
         normed_hidden_states, _ = self.ln(hidden_states, residual)
         return normed_hidden_states
 
-
 class FlashPhiForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -424,6 +412,4 @@ class FlashPhiForCausalLM(torch.nn.Module):
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
 
-        logits = self.lm_head(hidden_states)
-
-        return logits
+        return self.lm_head(hidden_states)
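The final hunk folds the temporary logits variable into a direct return. The lm_head_indices slicing just above it is what lets the model project only the positions it actually needs logits for (typically the last token of each sequence during prefill). Below is a small, self-contained sketch of that pattern with toy shapes and names, not the actual TGI classes.

import torch

# Toy shapes standing in for the real model dimensions.
hidden_size, vocab_size, seq_len = 16, 100, 8
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(seq_len, hidden_size)
lm_head_indices = torch.tensor([seq_len - 1])  # e.g. only the last position

if lm_head_indices is not None:
    # Keep only the rows we need logits for before the projection.
    hidden_states = hidden_states[lm_head_indices]

# Direct return of the projection, mirroring the simplified code path.
print(lm_head(hidden_states).shape)  # torch.Size([1, 100])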