Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
fix: improve model initialization
Commit 43441cad42 (parent 77ee1f18fa)
@@ -24,7 +24,6 @@ class PhiConfig(PretrainedConfig):
         self,
         vocab_size=51200,
         hidden_size=2560,
-        intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=None,
@@ -40,12 +39,12 @@ class PhiConfig(PretrainedConfig):
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        resid_pdrop=0.1,
         **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads

@@ -61,6 +60,7 @@ class PhiConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.resid_pdrop = resid_pdrop

         super().__init__(
             pad_token_id=pad_token_id,
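The two config hunks above drop the unused `intermediate_size` argument and thread a `resid_pdrop` value through `PhiConfig` so that the layers can build their residual dropout from it. A minimal sketch of how the new field is consumed downstream; `TinyPhiConfig` is a hypothetical stand-in using only the fields and defaults visible in these hunks, not the repository's class:

```python
import torch

# Hypothetical, trimmed-down stand-in for PhiConfig: only the fields touched
# by this commit, with the defaults shown in the hunks above.
class TinyPhiConfig:
    def __init__(self, hidden_size=2560, num_attention_heads=32, resid_pdrop=0.1):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.resid_pdrop = resid_pdrop

config = TinyPhiConfig()
# A later hunk constructs the layer dropout exactly like this:
resid_dropout = torch.nn.Dropout(config.resid_pdrop)
print(resid_dropout)  # Dropout(p=0.1, inplace=False)
```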
@@ -81,7 +81,7 @@ def load_attention(config, prefix, weights):
            config,
            prefix=f"{prefix}.W_pack",
            weights=weights,
-           bias=False,
+           bias=True,
        )
    else:
        # should be here
@@ -90,7 +90,7 @@ def load_attention(config, prefix, weights):
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
-           bias=False,
+           bias=True,
        )


@@ -116,7 +116,7 @@ def _load_gqa(config, prefix: str, weights):
    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
+        get_linear(weight, bias=True, quantize=config.quantize)
    )


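The bias flips from `False` to `True` throughout because the Phi checkpoints, unlike Llama-style ones, carry bias tensors on their attention and MLP projections. The fused QKV path concatenates `q_proj`, `k_proj` and `v_proj` along the output dimension; a self-contained sketch of that concatenation in plain PyTorch (made-up random weights and the Phi-2 sizes from the config hunks, standing in for TGI's `TensorParallelColumnLinear.load_multi`):

```python
import torch

hidden_size, num_heads, head_size = 2560, 32, 80  # sizes implied by the config hunks

# Stand-in per-projection weights and biases (the real ones come from the checkpoint).
q_w = torch.randn(num_heads * head_size, hidden_size)
k_w = torch.randn(num_heads * head_size, hidden_size)
v_w = torch.randn(num_heads * head_size, hidden_size)
q_b, k_b, v_b = (torch.randn(num_heads * head_size) for _ in range(3))

# Fused QKV: concatenate along the output dimension (dim=0), as in the hunk.
qkv_weight = torch.cat([q_w, k_w, v_w], dim=0)  # [7680, 2560]
qkv_bias = torch.cat([q_b, k_b, v_b], dim=0)    # [7680]

fused = torch.nn.Linear(hidden_size, 3 * num_heads * head_size)
with torch.no_grad():
    fused.weight.copy_(qkv_weight)
    fused.bias.copy_(qkv_bias)

x = torch.randn(4, hidden_size)
print(fused(x).shape)  # torch.Size([4, 7680])
```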
@@ -130,6 +130,7 @@ class FlashPhiAttention(torch.nn.Module):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
+        # should be 80 = 2560 / 32
        self.head_size = self.hidden_size // self.num_heads

        # MAYBE (if not static)
@@ -160,7 +161,7 @@ class FlashPhiAttention(torch.nn.Module):
            config,
            prefix=f"{prefix}.dense",
            weights=weights,
-           bias=False,
+           bias=True,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
@@ -180,6 +181,8 @@ class FlashPhiAttention(torch.nn.Module):
        max_s,
    ):
        qkv = self.query_key_value(hidden_states)
+        # shape = torch.Size([4096, 7680])
+
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
@@ -187,6 +190,8 @@ class FlashPhiAttention(torch.nn.Module):
            ],
            dim=1,
        )
+        # query = torch.Size([4096, 2560])
+        # kv = torch.Size([4096, 5120])
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

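The shape comments added here follow directly from the config: `head_size = 2560 / 32 = 80`, so the fused QKV output has `(num_heads + 2 * num_key_value_heads) * head_size = (32 + 64) * 80 = 7680` columns, which the split carves into a 2560-wide query block and a 5120-wide KV block. A quick check of that arithmetic with dummy tensors (assuming `num_key_value_heads == num_heads`, as the 5120 figure implies for this checkpoint):

```python
import torch

hidden_size, num_heads = 2560, 32
num_key_value_heads = num_heads        # assumption: no grouped-query attention here
head_size = hidden_size // num_heads   # 80, matching the "should be 80" comment

num_tokens = 4096
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)
print(qkv.shape)  # torch.Size([4096, 7680]), the shape noted in the hunk

query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
print(query.shape)  # torch.Size([4096, 2560])
print(kv.shape)     # torch.Size([4096, 5120])

query = query.view(-1, num_heads, head_size)          # [4096, 32, 80]
kv = kv.view(-1, 2, num_key_value_heads, head_size)   # [4096, 2, 32, 80]
```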
@@ -201,6 +206,8 @@ class FlashPhiAttention(torch.nn.Module):

        # Prefill
        if cu_seqlen_prefill is not None:
+            print("🧢 flash attention")
+            print("cu_seqlen_prefill", cu_seqlen_prefill.shape)
            # flash attention
            flash_attn.attention(
                query,
@@ -213,6 +220,7 @@ class FlashPhiAttention(torch.nn.Module):
            )
        # Decode
        else:
+            print("📗 paged attention")
            paged_attention.attention(
                attn_output,
                query,
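The new prints only tag which branch runs: when `cu_seqlen_prefill` is set, the whole prompt is processed in one flash-attention call; otherwise single-token decode goes through paged attention against the KV cache. As a rough, framework-free illustration of what the two branches compute (naive PyTorch attention standing in for both kernels; this is not TGI's `flash_attn` or `paged_attention` API):

```python
import torch

def prefill_attention(query, key, value):
    # Prefill: all prompt tokens at once, with a causal mask so token i sees only tokens <= i.
    q, k, v = (t.transpose(0, 1) for t in (query, key, value))   # [heads, T, d]
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    mask = torch.triu(torch.full(scores.shape[-2:], float("-inf")), diagonal=1)
    return ((scores + mask).softmax(dim=-1) @ v).transpose(0, 1)  # [T, heads, d]

def decode_attention(query, key_cache, value_cache):
    # Decode: one new token attends to everything already cached, no mask needed.
    q = query.transpose(0, 1)                                     # [heads, 1, d]
    k, v = key_cache.transpose(0, 1), value_cache.transpose(0, 1)
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    return (scores.softmax(dim=-1) @ v).transpose(0, 1)           # [1, heads, d]

heads, head_size = 32, 80
prompt_q = torch.randn(16, heads, head_size)
prompt_k = torch.randn(16, heads, head_size)
prompt_v = torch.randn(16, heads, head_size)
print(prefill_attention(prompt_q, prompt_k, prompt_v).shape)  # torch.Size([16, 32, 80])

new_q = torch.randn(1, heads, head_size)
print(decode_attention(new_q, prompt_k, prompt_v).shape)      # torch.Size([1, 32, 80])
```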
@@ -225,6 +233,11 @@ class FlashPhiAttention(torch.nn.Module):
                max_s,
            )

+        # TODO: remove this - only used to summarize attention weights
+        # get sum of the attention weights
+        my_sum = torch.sum(attn_output, dim=2)
+        print("my_sum", my_sum)
+
        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


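The temporary debug block sums the attention output over its last dimension; since `attn_output` is laid out as `[num_tokens, num_heads, head_size]`, that yields one scalar per token per head, presumably a cheap fingerprint for eyeballing the port against a reference run. The shapes involved:

```python
import torch

attn_output = torch.randn(4096, 32, 80)   # [num_tokens, num_heads, head_size]
my_sum = torch.sum(attn_output, dim=2)    # one scalar per token per head
print(my_sum.shape)                       # torch.Size([4096, 32])
```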
@@ -247,16 +260,17 @@ class PhiMLP(nn.Module):
            config,
            prefix=f"{prefix}.fc1",
            weights=weights,
-           bias=False,
+           bias=True,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.fc2",
            weights=weights,
-           bias=False,
+           bias=True,
        )

    def forward(self, hidden_states):
+        print("FORWARD MLP")
        gate_up_states = self.gate_up_proj(hidden_states)
        post_act = self.act(gate_up_states)
        return self.down_proj(post_act)
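With the bias change, `PhiMLP` is a plain two-layer feed-forward block: `fc1` expands the hidden size, an activation is applied, and `fc2` projects back down, both with biases. A minimal stand-alone equivalent; the 4x expansion and tanh-GELU are assumptions for illustration, and the real class uses TGI's tensor-parallel linears and the configured activation:

```python
import torch
from torch import nn

class TinyPhiMLP(nn.Module):
    # Illustrative only: expansion factor and activation are assumptions, and the
    # real module uses TensorParallelColumnLinear / TensorParallelRowLinear.
    def __init__(self, hidden_size=2560):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size, bias=True)
        self.fc2 = nn.Linear(4 * hidden_size, hidden_size, bias=True)
        self.act = nn.GELU(approximate="tanh")

    def forward(self, hidden_states):
        return self.fc2(self.act(self.fc1(hidden_states)))

mlp = TinyPhiMLP()
print(mlp(torch.randn(4, 2560)).shape)  # torch.Size([4, 2560])
```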
@@ -274,6 +288,7 @@ class FlashPhiLayer(nn.Module):
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
+        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)

    def forward(
        self,
@@ -288,11 +303,14 @@ class FlashPhiLayer(nn.Module):
        input_lengths,
        max_s,
    ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        print("💧 FORWARD LAYER")
+        print("\tinput0", hidden_states[0][1])
+        hidden_states, res = self.input_layernorm(hidden_states, residual)
+        print("\tnormalized shape", hidden_states.shape)

        # Self Attention
        attn_output = self.self_attn(
-            normed_hidden_states,
+            hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
@@ -303,12 +321,12 @@ class FlashPhiLayer(nn.Module):
            max_s,
        )

-        mlp_output = self.mlp(normed_hidden_states)
+        attn_output = self.resid_dropout(attn_output)

-        result = attn_output + mlp_output + res
+        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
+        hidden_states = attn_output + feed_forward_hidden_states

-        return result, res
+        return hidden_states, res


class FlashPhiModel(torch.nn.Module):
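The rewritten layer forward now follows Phi's parallel residual layout: a single layer norm feeds both the attention block and the MLP, each output passes through the new residual dropout, and the two are summed. In the repository the residual itself is carried separately in `res` and folded in by the next layer's fused norm; the sketch below adds it inline for readability, with placeholder submodules rather than the real attention and MLP classes:

```python
import torch
from torch import nn

class ParallelResidualBlock(nn.Module):
    # Sketch of the data flow only; attn and mlp are stand-in submodules.
    def __init__(self, hidden_size=2560, resid_pdrop=0.1):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)  # placeholder for self-attention
        self.mlp = nn.Linear(hidden_size, hidden_size)   # placeholder for PhiMLP
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def forward(self, hidden_states):
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_output = self.resid_dropout(self.attn(normed))
        feed_forward = self.resid_dropout(self.mlp(normed))
        # Parallel residual: attention and MLP both read the same normed input,
        # and their outputs are summed back onto the residual stream.
        return residual + attn_output + feed_forward

block = ParallelResidualBlock()
print(block(torch.randn(4, 2560)).shape)  # torch.Size([4, 2560])
```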
@@ -337,6 +355,12 @@ class FlashPhiModel(torch.nn.Module):
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

+        self.ln = FastLayerNorm.load(
+            prefix="model.final_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+
    def forward(
        self,
        input_ids: torch.Tensor,
@@ -371,7 +395,8 @@ class FlashPhiModel(torch.nn.Module):
                max_s,
            )

-        return hidden_states
+        normed_hidden_states, _ = self.ln(hidden_states, residual)
+        return normed_hidden_states


class FlashPhiForCausalLM(torch.nn.Module):
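These two hunks move `model.final_layernorm` into `FlashPhiModel` itself, so the model hands back already-normalized hidden states instead of leaving that step to the LM-head wrapper (the fused norm returns a `(normed, residual)` pair, hence the `_`). Conceptually the forward becomes the following, shown here with placeholder layers and a plain `nn.LayerNorm` rather than the real classes:

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    # Placeholder layers; the point is only where the final norm now lives.
    def __init__(self, hidden_size=2560, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.ln = nn.LayerNorm(hidden_size)  # model.final_layernorm, owned by the model

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return self.ln(hidden_states)        # normalized before reaching lm_head

print(TinyModel()(torch.randn(4, 2560)).shape)  # torch.Size([4, 2560])
```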
@@ -379,41 +404,22 @@ class FlashPhiForCausalLM(torch.nn.Module):
        super().__init__()

        self.model = FlashPhiModel(config, weights)
-        # self.lm_head = TensorParallelHead.load(
-        #     config,
-        #     prefix="lm_head",
-        #     weights=weights,
-        # )
-
-        # TODO: prefer parallel head
-        self.linear = FastLinear.load(
+        self.lm_head = TensorParallelHead.load(
            config,
            prefix="lm_head",
            weights=weights,
-           bias=False,
-       )
-
-       # TODO: use in correct place
-       self.ln = FastLayerNorm.load(
-           prefix="model.final_layernorm",
-           weights=weights,
-           eps=config.rms_norm_eps,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
-        # 1000K and 10K
-        cu_seqlen_prefill: Optional[torch.Tensor], # indexes for the items in the batch
+        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        # paged attention related
-        block_tables: torch.Tensor, # <- indexes into blocks
-        slots: torch.Tensor, # <- indexes into mem
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
        input_lengths: torch.Tensor,
-        # both attentions
-        max_s: int, # <- max sequence length (make kernals chose swap)
-        # small opt (only care about final)
+        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
@@ -429,7 +435,6 @@ class FlashPhiForCausalLM(torch.nn.Module):
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

-        normed_hidden_states, res = self.ln(hidden_states, None)
-        logits = self.linear(normed_hidden_states)
+        logits = self.lm_head(hidden_states)

        return logits
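With the final norm owned by the model, `FlashPhiForCausalLM` shrinks to the usual tail: run the model, optionally gather only the positions needed for sampling via `lm_head_indices`, and project to vocabulary logits through the tensor-parallel head (the previous `FastLinear` plus duplicate final layernorm is gone). A rough sketch of that tail with a plain `nn.Linear` standing in for `TensorParallelHead`:

```python
import torch
from torch import nn

hidden_size, vocab_size = 2560, 51200
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in for TensorParallelHead

hidden_states = torch.randn(4096, hidden_size)            # already final-layernormed by the model
lm_head_indices = torch.tensor([1023, 2047, 4095])        # e.g. the last token of each sequence

if lm_head_indices is not None:
    hidden_states = hidden_states[lm_head_indices]        # keep only positions we need logits for

logits = lm_head(hidden_states)
print(logits.shape)  # torch.Size([3, 51200])
```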