Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: adjust model config vars and other refactors
This commit is contained in: parent 99392376e6, commit 9bcd21a0b0
@@ -38,6 +38,7 @@ class PhiConfig(PretrainedConfig):
         rope_scaling=None,
         rope_theta=10000.0,
         resid_pdrop=0.1,  # llama doesn't have this
+        partial_rotary_factor=0.5,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -54,6 +55,7 @@ class PhiConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
+        self.partial_rotary_factor = partial_rotary_factor

         super().__init__(
             pad_token_id=pad_token_id,
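The two hunks above add partial_rotary_factor alongside the existing Phi-specific config values. A minimal sketch of a config that carries just the fields touched here, assuming a standard transformers PretrainedConfig subclass and taking the defaults from the diff (this is not the repository's PhiConfig):

    from transformers import PretrainedConfig

    # Trimmed-down sketch: only the fields touched in the hunks above are kept.
    class TinyPhiConfig(PretrainedConfig):
        def __init__(self, resid_pdrop=0.1, partial_rotary_factor=0.5, **kwargs):
            self.resid_pdrop = resid_pdrop                       # dropout on the residual branches
            self.partial_rotary_factor = partial_rotary_factor   # share of head dims that get RoPE
            super().__init__(**kwargs)

    cfg = TinyPhiConfig()
    print(cfg.partial_rotary_factor)  # 0.5

Keeping the factor in the config, rather than hard-coding 0.5, lets the attention code further down derive its rotary span per model.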
@@ -67,14 +69,6 @@ class PhiConfig(PretrainedConfig):
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=True,
-            )
-        else:
         return TensorParallelColumnLinear.load_multi(
             config,
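The hunk above drops the Baichuan W_pack branch inherited from the Llama loader, leaving a single decision: take the GQA path when the key/value head count differs from the query head count, otherwise fuse q/k/v into one column-parallel projection. A behaviour-only sketch with placeholder return values standing in for TGI's TensorParallelColumnLinear helpers:

    from dataclasses import dataclass

    @dataclass
    class Cfg:
        num_attention_heads: int
        num_key_value_heads: int

    def load_attention_sketch(config):
        # GQA models need a dedicated loader; everything else fuses q/k/v.
        if config.num_attention_heads != config.num_key_value_heads:
            return "gqa loader"       # placeholder for _load_gqa(...)
        return "fused qkv loader"     # placeholder for load_multi(..., bias=True)

    print(load_attention_sketch(Cfg(32, 32)))  # fused qkv loader
    print(load_attention_sketch(Cfg(32, 8)))   # gqa loader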
@@ -130,6 +124,7 @@ class FlashPhiAttention(torch.nn.Module):
         )

         self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)

         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
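The added rotary_dim line is plain arithmetic on config values. With the 0.5 factor from this diff and assumed Phi-2-like shape numbers (hidden_size 2560, 32 heads, neither of which appears in the hunk), the rotary span works out to half of each head:

    # Illustrative numbers; only the 0.5 factor comes from the diff above.
    hidden_size = 2560
    num_heads = 32
    partial_rotary_factor = 0.5

    head_size = hidden_size // num_heads                  # 80
    rotary_dim = int(partial_rotary_factor * head_size)   # 40
    print(head_size, rotary_dim)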
@@ -188,7 +183,7 @@ class FlashPhiAttention(torch.nn.Module):
         #
         # Apply partial positional embeddings in place
         self.rotary_emb(
-            query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
+            query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
             cos, sin
         )

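The functional change above is the slice bound: rotary embeddings are applied to the first rotary_dim channels of each head instead of slicing by num_heads, and the remaining channels pass through untouched. A generic partial-RoPE toy, not TGI's in-place PositionRotaryEmbedding kernel:

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_partial_rope(q, cos, sin, rotary_dim):
        # Rotate only the first rotary_dim channels, keep the rest as-is.
        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
        q_rot = q_rot * cos + rotate_half(q_rot) * sin
        return torch.cat((q_rot, q_pass), dim=-1)

    seq, heads, head_size, rotary_dim = 4, 2, 8, 4
    q = torch.randn(seq, heads, head_size)
    cos = torch.ones(seq, 1, rotary_dim)    # dummy angles for the sketch
    sin = torch.zeros(seq, 1, rotary_dim)
    out = apply_partial_rope(q, cos, sin, rotary_dim)
    assert torch.allclose(out, q)  # cos=1, sin=0 is the identity rotation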
@@ -243,7 +238,7 @@ class PhiMLP(nn.Module):
         )

         # llama weights are up_proj and down_proj and bias=False
-        self.gate_up_proj = TensorParallelRowLinear.load(
+        self.up_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.fc1",
             weights=weights,
@@ -259,9 +254,7 @@ class PhiMLP(nn.Module):
     def forward(self, hidden_states):
         # NOTE: Llama requires the gate up states to an intermediate size
         # Phi does not and we can avoid the `view` operation
-        gate_up_states = self.gate_up_proj(hidden_states)
-        post_act = self.act(gate_up_states)
-        return self.down_proj(post_act)
+        return self.down_proj(self.act(self.up_proj(hidden_states)))


 class FlashPhiLayer(nn.Module):
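The two PhiMLP hunks rename gate_up_proj to up_proj and collapse the forward pass: unlike Llama's gated MLP, Phi's fc1/fc2 stack has no gate half to split off, so the output is simply down_proj(act(up_proj(x))). A shape-level sketch with plain nn.Linear standing in for the tensor-parallel layers and made-up sizes:

    import torch
    from torch import nn

    hidden_size, intermediate_size = 16, 64                 # illustrative sizes
    up_proj = nn.Linear(hidden_size, intermediate_size)     # plays the role of fc1
    down_proj = nn.Linear(intermediate_size, hidden_size)   # plays the role of fc2
    act = nn.GELU()                                         # illustrative activation

    x = torch.randn(2, hidden_size)
    y = down_proj(act(up_proj(x)))   # same flow as the new forward above
    print(y.shape)                   # torch.Size([2, 16])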
@@ -304,10 +297,7 @@ class FlashPhiLayer(nn.Module):
             max_s,
         )

-        attn_output = self.resid_dropout(attn_output)
-
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
-        hidden_states = attn_output + feed_forward_hidden_states
+        hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))

         return hidden_states, res
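The last hunk folds the two dropout applications into one expression: the layer output is dropout(attn_output) plus dropout(mlp(hidden_states)), where both sublayers read the same hidden_states (Phi's parallel attention/MLP layout) and the raw residual keeps flowing through res. A minimal numeric sketch of that line, with random tensors standing in for the real sublayer outputs:

    import torch
    from torch import nn

    resid_dropout = nn.Dropout(p=0.1)      # p matches resid_pdrop's default in the config hunk
    hidden_states = torch.randn(2, 16)
    attn_output = torch.randn(2, 16)       # stand-in for the attention result
    mlp_output = torch.randn(2, 16)        # stand-in for self.mlp(hidden_states)

    # .add(...) is the out-of-place add, equivalent to `+`
    new_hidden = resid_dropout(attn_output).add(resid_dropout(mlp_output))
    print(new_hidden.shape)                # torch.Size([2, 16])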