From 853bc514f2d55f833d893ef3691d5eb130117692 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 30 Aug 2024 15:14:59 +0000
Subject: [PATCH] feat: support phi3.5 moe model loading

---
 .../text_generation_server/layers/rotary.py   |  71 ++++++
 .../text_generation_server/models/__init__.py |  27 +++
 .../custom_modeling/flash_phi_modeling.py     | 207 ++++++++++++++++--
 3 files changed, 287 insertions(+), 18 deletions(-)

diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index fc4a59b9..09ec0dd1 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -89,6 +89,22 @@ class PositionRotaryEmbedding(nn.Module):
 
             if rope_type == "linear":
                 pass
+            elif rope_type == "longrope":
+                inv_freq = apply_phi3_scaling(
+                    inv_freq,
+                    max_position_embeddings=config.max_position_embeddings,
+                    rope_theta=config.rope_theta,
+                    short_factor=rope_scaling["short_factor"],
+                    long_factor=rope_scaling["long_factor"],
+                    short_mscale=rope_scaling["short_mscale"],
+                    long_mscale=rope_scaling["long_mscale"],
+                    original_max_position_embeddings=rope_scaling[
+                        "original_max_position_embeddings"
+                    ],
+                    device=inv_freq.device,
+                    dim=dim,
+                )
+                return cls(inv_freq, scaling_factor)
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
@@ -475,3 +491,58 @@ def apply_llama3_scaling(
         new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
 
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+def apply_phi3_scaling(
+    freqs: torch.Tensor,
+    *,
+    max_position_embeddings: int,
+    rope_theta: float,
+    short_factor: torch.Tensor,
+    long_factor: torch.Tensor,
+    short_mscale: float,
+    long_mscale: float,
+    original_max_position_embeddings: int,
+    device=None,
+    dim=None,
+):
+    # The config provides one rescale factor per rotary frequency: the short
+    # factors cover the original training context, the long factors are used
+    # once the model is configured for a longer context window.
+    base = rope_theta
+    long_rescale_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
+    short_rescale_factors = torch.tensor(
+        short_factor, dtype=torch.float32, device=device
+    )
+
+    long_inv_freq = 1.0 / (
+        long_rescale_factors
+        * (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+    )
+    short_inv_freq = 1.0 / (
+        short_rescale_factors
+        * (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+    )
+
+    # short_mscale / long_mscale rescale cos/sin (attention magnitude) rather
+    # than the inverse frequencies, so they are not applied here.
+    if max_position_embeddings > original_max_position_embeddings:
+        return long_inv_freq.to(dtype=freqs.dtype)
+    return short_inv_freq.to(dtype=freqs.dtype)
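For reference, the mscale values passed through above are not folded into the inverse frequencies; in longrope-style implementations they scale cos/sin when the rotary cache is built. A minimal sketch of that convention, assuming the fallback formula used by the Phi-3 reference implementation (the helper name is illustrative, not part of this patch):

# Illustrative helper: pick the attention-magnitude correction for longrope.
# Explicit short/long mscales win; otherwise fall back to
# sqrt(1 + log(scale) / log(original_max_position_embeddings)).
import math


def longrope_mscale(
    max_position_embeddings,
    original_max_position_embeddings,
    short_mscale=None,
    long_mscale=None,
):
    long_context = max_position_embeddings > original_max_position_embeddings
    if short_mscale is not None and long_mscale is not None:
        return long_mscale if long_context else short_mscale
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        return 1.0
    return math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))


# cos/sin would then be scaled once when the rotary cache is materialized, e.g.
#   cos = torch.cos(freqs) * mscale
#   sin = torch.sin(freqs) * mscale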
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e5e5aabb..f22c5251 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -237,6 +237,11 @@ class ModelType(enum.Enum):
         "name": "Phi",
         "url": "https://huggingface.co/microsoft/phi-1_5",
     }
+    PHI_MOE = {
+        "type": "phimoe",
+        "name": "PhiMoe",
+        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
+    }
     BAICHUAN = {
         "type": "baichuan",
         "name": "Baichuan",
@@ -768,6 +773,28 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    elif model_type == PHI_MOE:
+        if FLASH_ATTENTION:
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashPhiForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     elif model_type == "phi-msft":
         if FLASH_ATTENTION:
             raise NotImplementedError(
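Before the modeling changes below, it helps to see which config fields the new code path keys on. A quick way to inspect them from the published checkpoint (illustrative; field names follow the public Phi-3.5-MoE-instruct config at the time of writing, and the snippet needs network access):

# Illustrative: print the config fields the patch relies on.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True
)
print(config.model_type)            # "phimoe" -> matches the PHI_MOE entry above
print(config.num_local_experts)     # experts merged per layer by _load_experts below
print(config.num_experts_per_tok)   # top_k experts routed per token
print(config.rope_scaling["type"])  # "longrope" -> handled by apply_phi3_scaling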
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 2a0dc606..047eca42 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -27,6 +27,15 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 
+from text_generation_server.layers import (
+    FastLinear,
+)
+
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM != "ipex":
+    from vllm.model_executor.layers.fused_moe import fused_moe
+
 
 class PhiConfig(PretrainedConfig):
     def __init__(
@@ -69,7 +78,16 @@ class PhiConfig(PretrainedConfig):
 
 # this is the same as llama except for Phi uses bias=True
 def load_attention(config, prefix, weights):
-    if config.num_attention_heads != config.num_key_value_heads:
+    if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+    if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
         return TensorParallelColumnLinear.load_multi(
@@ -90,7 +108,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq", "marlin"]:
+    if config.quantize and (config.quantize not in ["gptq", "awq", "marlin"]):
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads
@@ -105,6 +123,103 @@ def _load_gqa(config, prefix: str, weights):
     return TensorParallelColumnLinear(get_linear(weight, bias=True))
 
 
+def _load_experts(config, prefix: str, mat, weights):
+    if config.quantize is not None:
+        raise NotImplementedError("Phi3.5 MoE does not support weight quantization yet.")
+
+    assert mat in ["w1", "w2", "w3"]
+
+    world_size = weights.process_group.size()
+    rank = weights.process_group.rank()
+
+    assert (
+        config.intermediate_size % world_size == 0
+    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
+
+    block_size = config.intermediate_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size
+
+    tensor = torch.empty(
+        (config.num_local_experts * block_size, config.hidden_size),
+        dtype=weights.dtype,
+        device=weights.device,
+    )
+
+    for i in range(config.num_local_experts):
+        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
+
+        if mat == "w2":
+            expert_slice = slice_[:, start:stop].t().contiguous()
+        else:
+            expert_slice = slice_[start:stop]
+        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
+            dtype=weights.dtype
+        ).to(device=weights.device)
+    return tensor
+
+
+class BlockSparseMoE(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size // weights.process_group.size()
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        act = config.hidden_act
+        if "gelu" in act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
+            )
+        elif "silu" in act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[act]
+
+        # gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
+        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.w13 = torch.cat([w1, w3], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts", "w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
+            x,
+            self.w13,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out.view(*x.shape)
+
+
 class FlashPhiAttention(torch.nn.Module):
     def __init__(
         self,
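The forward pass above hands the merged expert weights to the fused_moe kernel. As a reading aid, the computation is roughly equivalent to the unfused reference below (a sketch only; the function name, toy shapes, and the per-token loop are illustrative, not how the kernel is implemented):

# Unfused reference for the top-k expert routing performed by fused_moe.
import torch


def naive_block_sparse_moe(x, w13, w2, router_logits, top_k,
                           act=torch.nn.functional.silu):
    # x: (tokens, hidden); w13: (experts, 2*ffn, hidden); w2: (experts, hidden, ffn)
    routing_weights = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(routing_weights, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize=True
    out = torch.zeros_like(x)
    for token in range(x.shape[0]):
        for weight, expert in zip(topk_weights[token], topk_ids[token]):
            gate_up = w13[expert] @ x[token]   # fused w1/w3 projection, (2*ffn,)
            gate, up = gate_up.chunk(2)        # first half is w1 (gate), second half is w3 (up)
            out[token] += weight * (w2[expert] @ (act(gate) * up))
    return out


# Toy shape check with arbitrary sizes:
tokens, hidden, ffn, experts, top_k = 3, 8, 16, 4, 2
x = torch.randn(tokens, hidden)
w13 = torch.randn(experts, 2 * ffn, hidden)
w2 = torch.randn(experts, hidden, ffn)
logits = torch.randn(tokens, experts)
print(naive_block_sparse_moe(x, w13, w2, logits, top_k).shape)  # torch.Size([3, 8])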
@@ -118,7 +233,12 @@ class FlashPhiAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.softmax_scale = self.head_size**-0.5
-        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
+        self.head_dim = self.head_size
+        if hasattr(config, "partial_rotary_factor"):
+            self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+        else:
+            self.rotary_dim = self.head_size
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -141,9 +261,10 @@ class FlashPhiAttention(torch.nn.Module):
         )
 
         self.query_key_value = load_attention(config, prefix, weights)
 
         # in llama the dense layer is called "o_proj" and has bias=False
+        proj_layer_name = f"{prefix}.o_proj" if self.use_moe else f"{prefix}.dense"
         self.dense = TensorParallelRowLinear.load(
             config,
-            prefix=f"{prefix}.dense",
+            prefix=proj_layer_name,
             weights=weights,
             bias=True,
         )
@@ -183,9 +304,14 @@ class FlashPhiAttention(torch.nn.Module):
         # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
         #
         # Apply partial positional embeddings in place
-        self.rotary_emb(
-            query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
-        )
+        if self.use_moe:
+            # the MoE model applies rotary embeddings to the full head dimension
+            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        else:
+            self.rotary_emb(
+                query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
+            )
 
         # Reshape key and value and cache
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
@@ -258,13 +384,30 @@ class FlashPhiLayer(nn.Module):
         self.self_attn = FlashPhiAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
+
+        if self.use_moe:
+            self.moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+        else:
+            self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
         self.input_layernorm = FastLayerNorm.load(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
-            eps=config.layer_norm_eps,
+            eps=1e-5 if self.use_moe else config.layer_norm_eps,
         )
-        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
+
+        if self.use_moe:
+            self.post_attn_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=1e-5,
+            )
+
+        self.resid_dropout = torch.nn.Dropout(
+            config.attention_dropout if self.use_moe else config.resid_pdrop
+        )
 
     def forward(
         self,
@@ -293,9 +436,28 @@ class FlashPhiLayer(nn.Module):
             max_s,
         )
 
-        hidden_states = self.resid_dropout(attn_output).add(
-            self.resid_dropout(self.mlp(hidden_states))
-        )
+        if self.use_moe:
+            hidden_states = self.resid_dropout(attn_output)
+            _hidden_states = self.resid_dropout(self.moe(hidden_states))
+            hidden_states = hidden_states.add(_hidden_states)
+        else:
+            hidden_states = self.resid_dropout(attn_output).add(
+                self.resid_dropout(self.mlp(hidden_states))
+            )
 
         return hidden_states, res
@@ -327,11 +489,20 @@ class FlashPhiModel(torch.nn.Module):
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
-        self.norm = FastLayerNorm.load(
-            prefix="model.final_layernorm",
-            weights=weights,
-            eps=config.layer_norm_eps,
-        )
+        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
+
+        if self.use_moe:
+            self.norm = FastLayerNorm.load(
+                prefix="model.norm",
+                weights=weights,
+                eps=1e-5,
+            )
+        else:
+            self.norm = FastLayerNorm.load(
+                prefix="model.final_layernorm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
 
     def forward(
         self,
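Once the model loads through this path, a quick way to obtain reference outputs to compare against is the stock transformers implementation (a sketch; it assumes enough GPU memory for the checkpoint, and the prompt and generation length are arbitrary):

# Illustrative parity check against the reference implementation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-MoE-instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))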