Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00

fix: adjust model config vars and other refactors

commit 9bcd21a0b0, parent 99392376e6
@@ -38,6 +38,7 @@ class PhiConfig(PretrainedConfig):
         rope_scaling=None,
         rope_theta=10000.0,
         resid_pdrop=0.1,  # llama doesn't have this
+        partial_rotary_factor=0.5,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -54,6 +55,7 @@ class PhiConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
+        self.partial_rotary_factor = partial_rotary_factor

         super().__init__(
             pad_token_id=pad_token_id,
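Together with the previous hunk, the config now carries a partial_rotary_factor that later hunks use to size the rotary slice. A minimal sketch of that relationship, using the default value from this diff and an assumed head size (not taken from any particular Phi checkpoint):

# Sketch only: how partial_rotary_factor becomes a rotary dimension.
partial_rotary_factor = 0.5   # default added by this commit
head_size = 64                # assumed example value
rotary_dim = int(partial_rotary_factor * head_size)
print(rotary_dim)             # 32: only the first 32 dims of each head get rotated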
@@ -67,14 +69,6 @@ class PhiConfig(PretrainedConfig):
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
-    else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=True,
-            )
-        else:
-            return TensorParallelColumnLinear.load_multi(
-                config,
+    else:
+        return TensorParallelColumnLinear.load_multi(
+            config,
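With the Baichuan-specific branch removed, every Phi checkpoint whose attention is not GQA goes through TensorParallelColumnLinear.load_multi, which fuses the separate q_proj/k_proj/v_proj weights into one projection; Baichuan's pre-packed W_pack case no longer applies here. A rough plain-PyTorch sketch of that fusion (toy sizes, not TGI's loader API):

import torch

hidden = 8                                  # toy hidden size
q_w, k_w, v_w = (torch.randn(hidden, hidden) for _ in range(3))

# Concatenate the three projections the way a fused qkv load would.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)

x = torch.randn(2, hidden)
q, k, v = (x @ qkv_w.T).split(hidden, dim=-1)
assert torch.allclose(q, x @ q_w.T, atol=1e-5)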
@@ -130,6 +124,7 @@ class FlashPhiAttention(torch.nn.Module):
         )

         self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)

         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -188,7 +183,7 @@ class FlashPhiAttention(torch.nn.Module):
         #
         # Apply partial positional embeddings in place
         self.rotary_emb(
-            query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
+            query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
             cos, sin
         )

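The slice now uses rotary_dim instead of num_heads, so the rotation touches only the first rotary_dim channels of each head and leaves the rest as-is. A self-contained sketch of partial rotary application with NeoX-style half splitting (assumed shapes and convention; the real kernel operates in place on the sliced views):

import torch

def partial_rotary(x, cos, sin, rotary_dim):
    # x: [tokens, heads, head_size]; cos/sin: [tokens, 1, rotary_dim // 2]
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = rot[..., : rotary_dim // 2], rot[..., rotary_dim // 2 :]
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated, rest), dim=-1)  # tail channels pass through unrotated

tokens, heads, head_size, rotary_dim = 3, 2, 8, 4
q = torch.randn(tokens, heads, head_size)
angle = torch.rand(tokens, 1, rotary_dim // 2)
out = partial_rotary(q, angle.cos(), angle.sin(), rotary_dim)
assert torch.equal(out[..., rotary_dim:], q[..., rotary_dim:])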
@@ -243,7 +238,7 @@ class PhiMLP(nn.Module):
         )

         # llama weights are up_proj and down_proj and bias=False
-        self.gate_up_proj = TensorParallelRowLinear.load(
+        self.up_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.fc1",
             weights=weights,
@@ -259,9 +254,7 @@ class PhiMLP(nn.Module):
     def forward(self, hidden_states):
         # NOTE: Llama requires the gate up states to an intermediate size
         # Phi does not and we can avoid the `view` operation
-        gate_up_states = self.gate_up_proj(hidden_states)
-        post_act = self.act(gate_up_states)
-        return self.down_proj(post_act)
+        return self.down_proj(self.act(self.up_proj(hidden_states)))


 class FlashPhiLayer(nn.Module):
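After the rename in the previous hunk, the MLP is a plain two-projection block: fc1, activation, fc2, with no gate projection and no view. A toy equivalent in vanilla PyTorch (assumed GELU activation and sizes; TGI loads tensor-parallel linears rather than nn.Linear):

import torch
from torch import nn

class ToyPhiMLP(nn.Module):
    def __init__(self, hidden_size=16, intermediate_size=64):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size)    # stands in for fc1
        self.down_proj = nn.Linear(intermediate_size, hidden_size)  # stands in for fc2
        self.act = nn.GELU(approximate="tanh")                      # assumed activation

    def forward(self, hidden_states):
        # Mirrors the diff's one-liner: project up, activate, project down.
        return self.down_proj(self.act(self.up_proj(hidden_states)))

out = ToyPhiMLP()(torch.randn(2, 16))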
@@ -304,10 +297,7 @@ class FlashPhiLayer(nn.Module):
             max_s,
         )

-        attn_output = self.resid_dropout(attn_output)
-
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
-        hidden_states = attn_output + feed_forward_hidden_states
+        hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))

         return hidden_states, res

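The collapsed line keeps the block's parallel structure: the attention output and the MLP output each pass through resid_dropout and are summed in one expression, while the residual bookkeeping through res is untouched by this hunk. A minimal numeric sketch of just that line (toy tensors; names mirror the diff):

import torch
from torch import nn

resid_dropout = nn.Dropout(p=0.1)   # assumed dropout probability
attn_output = torch.randn(4, 16)    # stand-in for the attention branch
mlp_output = torch.randn(4, 16)     # stand-in for self.mlp(hidden_states)

# One fused expression instead of two temporaries and an explicit add.
hidden_states = resid_dropout(attn_output).add(resid_dropout(mlp_output))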