Mirror of https://github.com/huggingface/text-generation-inference.git
feat: support phi3.5 moe model loading

commit 853bc514f2 (parent e790cfc0e4)

@@ -89,6 +89,22 @@ class PositionRotaryEmbedding(nn.Module):
             if rope_type == "linear":
                 pass
+            elif rope_type == "longrope":
+                inv_freq = apply_phi3_scaling(
+                    inv_freq,
+                    max_position_embeddings=config.max_position_embeddings,
+                    rope_theta=config.rope_theta,
+                    short_factor=rope_scaling["short_factor"],
+                    long_factor=rope_scaling["long_factor"],
+                    short_mscale=rope_scaling["short_mscale"],
+                    long_mscale=rope_scaling["long_mscale"],
+                    original_max_position_embeddings=rope_scaling[
+                        "original_max_position_embeddings"
+                    ],
+                    device=inv_freq.device,
+                    dim=dim,
+                )
+                return cls(inv_freq, scaling_factor)
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(

@@ -475,3 +491,58 @@ def apply_llama3_scaling(
             new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
 
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+def apply_phi3_scaling(
+    freqs: torch.Tensor,
+    *,
+    max_position_embeddings: int,
+    rope_theta: int,
+    short_factor: torch.Tensor,
+    long_factor: torch.Tensor,
+    short_mscale: float,
+    long_mscale: float,
+    original_max_position_embeddings: int,
+    device=None,
+    dim=None,
+):
+    base = rope_theta
+    long_rescale_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
+    short_rescale_factors = torch.tensor(
+        short_factor, dtype=torch.float32, device=device
+    )
+
+    long_inv_freq = 1.0 / (
+        long_rescale_factors
+        * (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+    )
+    short_inv_freq = 1.0 / (
+        short_rescale_factors
+        * (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+    )
+    # original_max_position_embeddings = torch.tensor(original_max_position_embeddings, device=device)
+    # low_freq_factor = torch.tensor(long_factor, dtype=torch.float32, device=device)
+    # high_freq_factor = torch.tensor(short_factor, dtype=torch.float32, device=device)
+
+    # low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+    # high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+    new_freqs = []
+
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+
+        # if wavelen < high_freq_wavelen:
+        if True:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / short_mscale)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append(
+                (1 - smooth) * freq / short_mscale + smooth * freq / long_mscale
+            )
+
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
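
Note (not part of the diff): as committed, the helper computes long_inv_freq and short_inv_freq but never uses them, and the hard-coded "if True:" means every frequency is passed through unchanged, so the smoothing path below it is effectively dead code. A minimal sketch of calling it, assuming apply_phi3_scaling is importable from text_generation_server.layers.rotary (the module these hunks appear to modify) and using made-up factor values rather than the real Phi-3.5-MoE numbers:

    import torch
    from text_generation_server.layers.rotary import apply_phi3_scaling  # assumed import path

    dim = 64  # illustrative rotary dimension, not the real model's
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))

    scaled = apply_phi3_scaling(
        inv_freq,
        max_position_embeddings=131072,
        rope_theta=10000,
        short_factor=[1.0] * (dim // 2),
        long_factor=[4.0] * (dim // 2),
        short_mscale=1.0,
        long_mscale=1.2,
        original_max_position_embeddings=4096,
        device=inv_freq.device,
        dim=dim,
    )
    # With the hard-coded "if True:" every freq is appended as-is,
    # so the result equals the input inverse frequencies.
    assert torch.allclose(scaled, inv_freq)
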
@@ -237,6 +237,11 @@ class ModelType(enum.Enum):
         "name": "Phi",
         "url": "https://huggingface.co/microsoft/phi-1_5",
     }
+    PHI_MOE = {
+        "type": "phimoe",
+        "name": "PhiMoe",
+        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
+    }
     BAICHUAN = {
         "type": "baichuan",
         "name": "Baichuan",
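
Note (an assumption about the surrounding dispatch code, not shown in this hunk): the "type" value is matched against the model_type field of the checkpoint's config.json, so a Phi-3.5-MoE checkpoint is expected to identify itself roughly as follows (field values illustrative, not taken from this diff):

    # Illustrative config.json excerpt, expressed as a Python dict
    {
        "model_type": "phimoe",
        "architectures": ["PhiMoEForCausalLM"],
        "num_local_experts": 16,
        "num_experts_per_tok": 2,
    }
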
@@ -768,6 +773,28 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
+
+    elif model_type == PHI_MOE:
+        if FLASH_ATTENTION:
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashPhiForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=True,  # trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     elif model_type == "phi-msft":
         if FLASH_ATTENTION:
             raise NotImplementedError(
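
Note: rather than adding a dedicated PhiMoE model class, the new branch reuses FlashCausalLM with the existing FlashPhiForCausalLM; the MoE-specific behaviour is toggled inside the flash Phi modeling code by checking config._name_or_path (see the hunks below). Also note that trust_remote_code is hard-coded to True on the flash path, with the caller-supplied value left commented out.
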
@@ -27,6 +27,15 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+
+from text_generation_server.layers import (
+    FastLinear,
+)
+
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM != "ipex":
+    from vllm.model_executor.layers.fused_moe import fused_moe
 
 
 class PhiConfig(PretrainedConfig):
     def __init__(

@@ -69,7 +78,16 @@ class PhiConfig(PretrainedConfig):
 
 # this is the same as llama except for Phi uses bias=True
 def load_attention(config, prefix, weights):
-    if config.num_attention_heads != config.num_key_value_heads:
+    if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+    if config.num_attention_heads != config.num_key_value_heads \
+            and False:
         return _load_gqa(config, prefix, weights)
     else:
         return TensorParallelColumnLinear.load_multi(
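
Note (not part of the diff): the MoE path keys off the exact string in config._name_or_path, so it only triggers when the model is referenced by that hub id, and the appended "and False" disables the GQA branch entirely. A small illustration of the first point, assuming _name_or_path behaves as in transformers' PretrainedConfig:

    class _Cfg:
        # A local copy of the same weights would carry its filesystem path here
        _name_or_path = "/data/phi-3.5-moe"

    use_moe = _Cfg._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
    print(use_moe)  # False -> the dense Phi loading path would be taken
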
@@ -90,7 +108,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq", "marlin"]:
+    if config.quantize and (config.quantize not in ["gptq", "awq", "marlin"]):
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads

@@ -105,6 +123,103 @@ def _load_gqa(config, prefix: str, weights):
     return TensorParallelColumnLinear(get_linear(weight, bias=True))
+
+
+def _load_experts(config, prefix: str, mat, weights):
+    if config.quantize is not None:
+        raise NotImplementedError("Mixtral does not support weight quantization yet.")
+
+    assert mat in ["w1", "w2", "w3"]
+
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
+
+    assert (
+        config.intermediate_size % world_size == 0
+    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
+
+    block_size = config.intermediate_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size
+
+    tensor = torch.empty(
+        (config.num_local_experts * block_size, config.hidden_size),
+        dtype=weights.dtype,
+        device=weights.device,
+    )
+
+    for i in range(config.num_local_experts):
+        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
+
+        if mat == "w2":
+            expert_slice = slice_[:, start:stop].t().contiguous()
+        else:
+            expert_slice = slice_[start:stop]
+        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
+            dtype=weights.dtype
+        ).to(device=weights.device)
+    return tensor
+
+
+class BlockSparseMoE(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size // weights.process_group.size()
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        act = config.hidden_act
+        if "gelu" in act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
+            )
+        elif "silu" in act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[act]
+
+        # gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
+        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.w13 = torch.cat([w1, w3], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts", "w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
+            x,
+            self.w13,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out.view(*x.shape)
+
+
 class FlashPhiAttention(torch.nn.Module):
     def __init__(
         self,
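
Note (not part of the diff): a shape sketch of the merged expert weights built above, with small made-up sizes in place of the real hidden/FFN/expert dimensions:

    import torch

    num_experts, ffn_dim, hidden_dim = 4, 8, 16  # placeholders, not the real sizes
    w1 = torch.randn(num_experts, ffn_dim, hidden_dim)
    w3 = torch.randn(num_experts, ffn_dim, hidden_dim)
    w13 = torch.cat([w1, w3], dim=1)  # gate and up projections stacked: (4, 16, 16)
    w2 = torch.randn(num_experts, ffn_dim, hidden_dim).transpose(1, 2).contiguous()
    print(w13.shape, w2.shape)  # torch.Size([4, 16, 16]) torch.Size([4, 16, 8])

fused_moe then takes the token activations, these two merged weights, and the router logits, and sums the top_k expert outputs per token; the all_reduce combines partial results across shards when the weights are tensor-parallel.
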
@@ -118,7 +233,12 @@ class FlashPhiAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.softmax_scale = self.head_size**-0.5
+        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
+        self.head_dim = self.head_size
+        if hasattr(config, "partial_rotary_factor"):
             self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+        else:
+            self.rotary_dim = self.head_size
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,

@@ -141,9 +261,10 @@
         self.query_key_value = load_attention(config, prefix, weights)
 
         # in llama the dense layer is called "o_proj" and has bias=False
+        proj_layer_name = f"{prefix}.o_proj" if self.use_moe else f"{prefix}.dense"
         self.dense = TensorParallelRowLinear.load(
             config,
-            prefix=f"{prefix}.dense",
+            prefix=proj_layer_name,
             weights=weights,
             bias=True,
         )

@@ -183,6 +304,11 @@
         # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
         #
         # Apply partial positional embeddings in place
+        if self.use_moe:
+            # rotate half rotary_dim
+            # half_size = torch.select(kv, dim=1, index=0).size(1) // 2
+            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        else:
             self.rotary_emb(
                 query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
             )
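
Note (not part of the diff): in the MoE branch the rotary embedding is applied over the full head dimension of the query and key, instead of only the first self.rotary_dim dims as in the dense path. torch.select(kv, dim=1, index=0) is just the key half of the fused kv tensor:

    import torch

    kv = torch.randn(3, 2, 4, 8)  # (tokens, 2 [key|value], heads, head_dim) -- illustrative shape
    assert torch.equal(torch.select(kv, dim=1, index=0), kv[:, 0])
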
@@ -258,13 +384,30 @@ class FlashPhiLayer(nn.Module):
         self.self_attn = FlashPhiAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
+
+        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
+
+        if self.use_moe:
+            self.moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+        else:
             self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
         self.input_layernorm = FastLayerNorm.load(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
-            eps=config.layer_norm_eps,
+            # eps=config.layer_norm_eps,
+            eps=1e-5,
         )
-        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
+
+        if self.use_moe:
+            self.post_attn_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=1e-5,
+            )
+
+        # self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
+        self.resid_dropout = torch.nn.Dropout(config.attention_dropout)
+
     def forward(
         self,
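
Note: the layer-norm epsilon is pinned to 1e-5 (the config-driven value is commented out) and the residual dropout now reads config.attention_dropout instead of config.resid_pdrop; presumably the PhiMoE config does not define layer_norm_eps or resid_pdrop, though the diff itself does not say so.
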
@@ -293,6 +436,25 @@
             max_s,
         )
+
+        if self.use_moe:
+            # hidden_states = attn_output
+            # if residual is not None:
+            #     hidden_states = hidden_states + residual
+            # residual = hidden_states
+            # hidden_states, router_states = self.post_attn_layernorm(
+            #     hidden_states
+            # )
+            # hidden_states = self.moe(hidden_states)
+            # # hidden_states = hidden_states + residual
+            # res = residual
+
+            hidden_states = self.resid_dropout(attn_output)
+            _hidden_states = self.resid_dropout(self.moe(hidden_states))
+            hidden_states = hidden_states.add(_hidden_states)
+
+        else:
             hidden_states = self.resid_dropout(attn_output).add(
                 self.resid_dropout(self.mlp(hidden_states))
             )
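
Note: despite the commented-out sequential-residual sketch, the MoE branch keeps the dense Phi layer's parallel formulation, i.e. hidden_states = dropout(attn_output) + dropout(moe(hidden_states)), with the MLP simply swapped for the block-sparse MoE.
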
@@ -327,6 +489,15 @@ class FlashPhiModel(torch.nn.Module):
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
+        if self.use_moe:
+            self.norm = FastLayerNorm.load(
+                prefix="model.norm",
+                weights=weights,
+                eps=1e-5,
+            )
+        else:
             self.norm = FastLayerNorm.load(
                 prefix="model.final_layernorm",
                 weights=weights,