mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: prefer llama base model and improve rotary logic
This commit is contained in:
parent
853bc514f2
commit
dff1b9f795
@ -90,21 +90,43 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
if rope_type == "linear":
|
||||
pass
|
||||
elif rope_type == "longrope":
|
||||
inv_freq = apply_phi3_scaling(
|
||||
inv_freq,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
rope_theta=config.rope_theta,
|
||||
short_factor=rope_scaling["short_factor"],
|
||||
long_factor=rope_scaling["long_factor"],
|
||||
short_mscale=rope_scaling["short_mscale"],
|
||||
long_mscale=rope_scaling["long_mscale"],
|
||||
original_max_position_embeddings=rope_scaling[
|
||||
"original_max_position_embeddings"
|
||||
],
|
||||
device=inv_freq.device,
|
||||
dim=dim,
|
||||
# Phi3LongRoPEScaledRotaryEmbedding
|
||||
short_factor = torch.tensor(
|
||||
rope_scaling["short_factor"], dtype=torch.float32, device=device
|
||||
)
|
||||
return cls(inv_freq, scaling_factor)
|
||||
long_factor = torch.tensor(
|
||||
rope_scaling["long_factor"], dtype=torch.float32, device=device
|
||||
)
|
||||
short_mscale = rope_scaling["short_mscale"]
|
||||
long_mscale = rope_scaling["long_mscale"]
|
||||
original_max_position_embeddings = (
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
return Phi3LongRoPEScaledRotaryEmbedding(
|
||||
short_inv_freq=1.0
|
||||
/ (
|
||||
short_factor
|
||||
* base
|
||||
** (
|
||||
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
|
||||
/ dim
|
||||
)
|
||||
),
|
||||
long_inv_freq=1.0
|
||||
/ (
|
||||
long_factor
|
||||
* base
|
||||
** (
|
||||
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
|
||||
/ dim
|
||||
)
|
||||
),
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
short_mscale=short_mscale,
|
||||
long_mscale=long_mscale,
|
||||
original_max_position_embeddings=original_max_position_embeddings,
|
||||
)
|
||||
|
||||
elif rope_type == "dynamic":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
return DynamicPositionRotaryEmbedding(
|
||||
@ -324,6 +346,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
|
||||
|
||||
|
||||
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
short_inv_freq,
|
||||
long_inv_freq,
|
||||
max_position_embeddings,
|
||||
short_mscale,
|
||||
long_mscale,
|
||||
original_max_position_embeddings,
|
||||
):
|
||||
super(PositionRotaryEmbedding, self).__init__()
|
||||
self.short_inv_freq = short_inv_freq
|
||||
self.long_inv_freq = long_inv_freq
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.short_mscale = short_mscale
|
||||
self.long_mscale = long_mscale
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
|
||||
# cache
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
self.dynamic_args = None
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
|
||||
|
||||
short_freqs = torch.outer(
|
||||
t[: self.original_max_position_embeddings],
|
||||
self.short_inv_freq.to(device=t.device),
|
||||
)
|
||||
|
||||
long_freqs = torch.outer(
|
||||
t[self.original_max_position_embeddings :],
|
||||
self.long_inv_freq.to(device=t.device),
|
||||
)
|
||||
|
||||
short_freqs = short_freqs * self.short_mscale
|
||||
long_freqs = long_freqs * self.long_mscale
|
||||
|
||||
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
|
||||
freqs[: self.original_max_position_embeddings] = short_freqs
|
||||
freqs[self.original_max_position_embeddings :] = long_freqs
|
||||
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
@ -491,58 +570,3 @@ def apply_llama3_scaling(
|
||||
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
|
||||
|
||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
def apply_phi3_scaling(
|
||||
freqs: torch.Tensor,
|
||||
*,
|
||||
max_position_embeddings: int,
|
||||
rope_theta: int,
|
||||
short_factor: torch.Tensor,
|
||||
long_factor: torch.Tensor,
|
||||
short_mscale: float,
|
||||
long_mscale: float,
|
||||
original_max_position_embeddings: int,
|
||||
device=None,
|
||||
dim=None,
|
||||
):
|
||||
base = rope_theta
|
||||
long_rescale_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
||||
short_rescale_factors = torch.tensor(
|
||||
short_factor, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
long_inv_freq = 1.0 / (
|
||||
long_rescale_factors
|
||||
* (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
)
|
||||
short_inv_freq = 1.0 / (
|
||||
short_rescale_factors
|
||||
* (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
)
|
||||
# original_max_position_embeddings = torch.tensor(original_max_position_embeddings, device=device)
|
||||
# low_freq_factor = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
||||
# high_freq_factor = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
||||
|
||||
# low_freq_wavelen = original_max_position_embeddings / low_freq_factor
|
||||
# high_freq_wavelen = original_max_position_embeddings / high_freq_factor
|
||||
new_freqs = []
|
||||
|
||||
for freq in freqs:
|
||||
wavelen = 2 * math.pi / freq
|
||||
|
||||
# if wavelen < high_freq_wavelen:
|
||||
if True:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / short_mscale)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
|
||||
high_freq_factor - low_freq_factor
|
||||
)
|
||||
new_freqs.append(
|
||||
(1 - smooth) * freq / short_mscale + smooth * freq / long_mscale
|
||||
)
|
||||
|
||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
@ -777,7 +777,7 @@ def get_model(
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashPhiForCausalLM,
|
||||
model_class=FlashLlamaForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
|
@ -47,11 +47,17 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
)
|
||||
from text_generation_server.utils.weights import (
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
|
||||
if SYSTEM != "ipex":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
from vllm import _custom_C
|
||||
@ -245,6 +251,103 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def _load_experts(config, prefix: str, mat, weights):
|
||||
if config.quantize is not None:
|
||||
raise NotImplementedError("Mixtral does not support weight quantization yet.")
|
||||
|
||||
assert mat in ["w1", "w2", "w3"]
|
||||
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
assert (
|
||||
config.intermediate_size % world_size == 0
|
||||
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||
|
||||
block_size = config.intermediate_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
tensor = torch.empty(
|
||||
(config.num_local_experts * block_size, config.hidden_size),
|
||||
dtype=weights.dtype,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
for i in range(config.num_local_experts):
|
||||
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||
|
||||
if mat == "w2":
|
||||
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||
else:
|
||||
expert_slice = slice_[start:stop]
|
||||
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
|
||||
dtype=weights.dtype
|
||||
).to(device=weights.device)
|
||||
return tensor
|
||||
|
||||
|
||||
class BlockSparseMoE(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.ffn_dim = config.intermediate_size // weights.process_group.size()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
act = config.hidden_act
|
||||
if "gelu" in act:
|
||||
self.act = lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
elif "silu" in act:
|
||||
self.act = torch.nn.functional.silu
|
||||
else:
|
||||
self.act = ACT2FN[act]
|
||||
|
||||
# gating
|
||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
self.w13 = torch.cat([w1, w3], dim=1)
|
||||
self.w2 = (
|
||||
_load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
def forward(self, x, adapter_data) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
self.w13,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out.view(*x.shape)
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights, index):
|
||||
super().__init__()
|
||||
@ -353,8 +456,13 @@ class FlashLlamaLayer(nn.Module):
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.mlp = LlamaMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
|
||||
|
||||
if self.use_moe:
|
||||
self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
|
||||
else:
|
||||
self.dense = LlamaMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
@ -401,7 +509,7 @@ class FlashLlamaLayer(nn.Module):
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
mlp_output = self.dense(normed_attn_res_output, adapter_data)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
@ -27,15 +27,6 @@ from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
|
||||
from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM != "ipex":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
@ -78,16 +69,7 @@ class PhiConfig(PretrainedConfig):
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
def load_attention(config, prefix, weights):
|
||||
if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
if config.num_attention_heads != config.num_key_value_heads \
|
||||
and False:
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
@ -108,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize and (config.quantize not in ["gptq", "awq", "marlin"]):
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
@ -123,103 +105,6 @@ def _load_gqa(config, prefix: str, weights):
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=True))
|
||||
|
||||
|
||||
def _load_experts(config, prefix: str, mat, weights):
|
||||
if config.quantize is not None:
|
||||
raise NotImplementedError("Mixtral does not support weight quantization yet.")
|
||||
|
||||
assert mat in ["w1", "w2", "w3"]
|
||||
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
assert (
|
||||
config.intermediate_size % world_size == 0
|
||||
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||
|
||||
block_size = config.intermediate_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
tensor = torch.empty(
|
||||
(config.num_local_experts * block_size, config.hidden_size),
|
||||
dtype=weights.dtype,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
for i in range(config.num_local_experts):
|
||||
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||
|
||||
if mat == "w2":
|
||||
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||
else:
|
||||
expert_slice = slice_[start:stop]
|
||||
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
|
||||
dtype=weights.dtype
|
||||
).to(device=weights.device)
|
||||
return tensor
|
||||
|
||||
|
||||
class BlockSparseMoE(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.ffn_dim = config.intermediate_size // weights.process_group.size()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
act = config.hidden_act
|
||||
if "gelu" in act:
|
||||
self.act = lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
elif "silu" in act:
|
||||
self.act = torch.nn.functional.silu
|
||||
else:
|
||||
self.act = ACT2FN[act]
|
||||
|
||||
# gating
|
||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
self.w13 = torch.cat([w1, w3], dim=1)
|
||||
self.w2 = (
|
||||
_load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
self.w13,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out.view(*x.shape)
|
||||
|
||||
|
||||
class FlashPhiAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -233,12 +118,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
|
||||
self.head_dim = self.head_size
|
||||
if hasattr(config, "partial_rotary_factor"):
|
||||
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
|
||||
else:
|
||||
self.rotary_dim = self.head_size
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
@ -261,10 +141,9 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
# in llama the dense layer is called "o_proj" and has bias=False
|
||||
proj_layer_name = f"{prefix}.o_proj" if self.use_moe else f"{prefix}.dense"
|
||||
self.dense = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=proj_layer_name,
|
||||
prefix=f"{prefix}.dense",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
@ -304,11 +183,6 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
|
||||
#
|
||||
# Apply partial positional embeddings in place
|
||||
if self.use_moe:
|
||||
# rotate half rotary_dim
|
||||
# half_size = torch.select(kv, dim=1, index=0).size(1) // 2
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
else:
|
||||
self.rotary_emb(
|
||||
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
|
||||
)
|
||||
@ -384,30 +258,13 @@ class FlashPhiLayer(nn.Module):
|
||||
self.self_attn = FlashPhiAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
|
||||
|
||||
if self.use_moe:
|
||||
self.moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
|
||||
else:
|
||||
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
# eps=config.layer_norm_eps,
|
||||
eps=1e-5,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
if self.use_moe:
|
||||
self.post_attn_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=1e-5,
|
||||
)
|
||||
|
||||
# self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
|
||||
self.resid_dropout = torch.nn.Dropout(config.attention_dropout)
|
||||
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -436,25 +293,6 @@ class FlashPhiLayer(nn.Module):
|
||||
max_s,
|
||||
)
|
||||
|
||||
|
||||
if self.use_moe:
|
||||
# hidden_states = attn_output
|
||||
# if residual is not None:
|
||||
# hidden_states = hidden_states + residual
|
||||
# residual = hidden_states
|
||||
# hidden_states, router_states = self.post_attn_layernorm(
|
||||
# hidden_states
|
||||
# )
|
||||
# hidden_states = self.moe(hidden_states)
|
||||
# # hidden_states = hidden_states + residual
|
||||
# res = residual
|
||||
|
||||
hidden_states = self.resid_dropout(attn_output)
|
||||
_hidden_states = self.resid_dropout(self.moe(hidden_states))
|
||||
hidden_states = hidden_states.add(_hidden_states)
|
||||
|
||||
|
||||
else:
|
||||
hidden_states = self.resid_dropout(attn_output).add(
|
||||
self.resid_dropout(self.mlp(hidden_states))
|
||||
)
|
||||
@ -489,15 +327,6 @@ class FlashPhiModel(torch.nn.Module):
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
|
||||
|
||||
if self.use_moe:
|
||||
self.norm = FastLayerNorm.load(
|
||||
prefix="model.norm",
|
||||
weights=weights,
|
||||
eps=1e-5,
|
||||
)
|
||||
else:
|
||||
self.norm = FastLayerNorm.load(
|
||||
prefix="model.final_layernorm",
|
||||
weights=weights,
|
||||
|
Loading…
Reference in New Issue
Block a user