diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 09ec0dd1..d59b68cd 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -90,21 +90,43 @@ class PositionRotaryEmbedding(nn.Module):
             if rope_type == "linear":
                 pass
             elif rope_type == "longrope":
-                inv_freq = apply_phi3_scaling(
-                    inv_freq,
-                    max_position_embeddings=config.max_position_embeddings,
-                    rope_theta=config.rope_theta,
-                    short_factor=rope_scaling["short_factor"],
-                    long_factor=rope_scaling["long_factor"],
-                    short_mscale=rope_scaling["short_mscale"],
-                    long_mscale=rope_scaling["long_mscale"],
-                    original_max_position_embeddings=rope_scaling[
-                        "original_max_position_embeddings"
-                    ],
-                    device=inv_freq.device,
-                    dim=dim,
-                )
-                return cls(inv_freq, scaling_factor)
+                # Phi3LongRoPEScaledRotaryEmbedding
+                short_factor = torch.tensor(
+                    rope_scaling["short_factor"], dtype=torch.float32, device=device
+                )
+                long_factor = torch.tensor(
+                    rope_scaling["long_factor"], dtype=torch.float32, device=device
+                )
+                short_mscale = rope_scaling["short_mscale"]
+                long_mscale = rope_scaling["long_mscale"]
+                original_max_position_embeddings = (
+                    config.original_max_position_embeddings
+                )
+                return Phi3LongRoPEScaledRotaryEmbedding(
+                    short_inv_freq=1.0
+                    / (
+                        short_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    ),
+                    long_inv_freq=1.0
+                    / (
+                        long_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    ),
+                    max_position_embeddings=config.max_position_embeddings,
+                    short_mscale=short_mscale,
+                    long_mscale=long_mscale,
+                    original_max_position_embeddings=original_max_position_embeddings,
+                )
+
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
@@ -324,6 +346,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
             self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
 
 
+class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
+    def __init__(
+        self,
+        short_inv_freq,
+        long_inv_freq,
+        max_position_embeddings,
+        short_mscale,
+        long_mscale,
+        original_max_position_embeddings,
+    ):
+        super(PositionRotaryEmbedding, self).__init__()
+        self.short_inv_freq = short_inv_freq
+        self.long_inv_freq = long_inv_freq
+        self.max_position_embeddings = max_position_embeddings
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
+        self.original_max_position_embeddings = original_max_position_embeddings
+
+        # cache
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
+        self.dynamic_args = None
+
+    def _update_cos_sin_cache(self, dtype, device, seqlen):
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+        ):
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+
+            short_freqs = torch.outer(
+                t[: self.original_max_position_embeddings],
+                self.short_inv_freq.to(device=t.device),
+            )
+
+            long_freqs = torch.outer(
+                t[self.original_max_position_embeddings :],
+                self.long_inv_freq.to(device=t.device),
+            )
+
+            short_freqs = short_freqs * self.short_mscale
+            long_freqs = long_freqs * self.long_mscale
+
+            freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
+            freqs[: self.original_max_position_embeddings] = short_freqs
+            freqs[self.original_max_position_embeddings :] = long_freqs
+
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+
+
 class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
     def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
         inv_freq = _create_inv_freq(dim, base, device)
@@ -491,58 +570,3 @@ def apply_llama3_scaling(
             new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
 
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
-
-
-def apply_phi3_scaling(
-    freqs: torch.Tensor,
-    *,
-    max_position_embeddings: int,
-    rope_theta: int,
-    short_factor: torch.Tensor,
-    long_factor: torch.Tensor,
-    short_mscale: float,
-    long_mscale: float,
-    original_max_position_embeddings: int,
-    device=None,
-    dim=None,
-):
-    base = rope_theta
-    long_rescale_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
-    short_rescale_factors = torch.tensor(
-        short_factor, dtype=torch.float32, device=device
-    )
-
-    long_inv_freq = 1.0 / (
-        long_rescale_factors
-        * (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-    )
-    short_inv_freq = 1.0 / (
-        short_rescale_factors
-        * (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-    )
-    # original_max_position_embeddings = torch.tensor(original_max_position_embeddings, device=device)
-    # low_freq_factor = torch.tensor(long_factor, dtype=torch.float32, device=device)
-    # high_freq_factor = torch.tensor(short_factor, dtype=torch.float32, device=device)
-
-    # low_freq_wavelen = original_max_position_embeddings / low_freq_factor
-    # high_freq_wavelen = original_max_position_embeddings / high_freq_factor
-    new_freqs = []
-
-    for freq in freqs:
-        wavelen = 2 * math.pi / freq
-
-        # if wavelen < high_freq_wavelen:
-        if True:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / short_mscale)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
-                high_freq_factor - low_freq_factor
-            )
-            new_freqs.append(
-                (1 - smooth) * freq / short_mscale + smooth * freq / long_mscale
-            )
-
-    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index f22c5251..f104d99b 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -777,7 +777,7 @@ def get_model(
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
-                model_class=FlashPhiForCausalLM,
+                model_class=FlashLlamaForCausalLM,
                 revision=revision,
                 quantize=quantize,
                 speculator=speculator,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 758e39aa..d6eb8080 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -47,11 +47,17 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.layers import (
+    FastLinear,
+)
 from text_generation_server.utils.weights import (
     Weights,
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
+if SYSTEM != "ipex":
+    from vllm.model_executor.layers.fused_moe import fused_moe
+
 if SYSTEM == "rocm":
"rocm": try: from vllm import _custom_C @@ -245,6 +251,103 @@ class FlashLlamaAttention(torch.nn.Module): ) +def _load_experts(config, prefix: str, mat, weights): + if config.quantize is not None: + raise NotImplementedError("Mixtral does not support weight quantization yet.") + + assert mat in ["w1", "w2", "w3"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.intermediate_size % world_size == 0 + ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.num_local_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) + + for i in range(config.num_local_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "w2": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +class BlockSparseMoE(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size // weights.process_group.size() + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + act = config.hidden_act + if "gelu" in act: + self.act = lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + elif "silu" in act: + self.act = torch.nn.functional.silu + else: + self.act = ACT2FN[act] + + # gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) + w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + self.w13 = torch.cat([w1, w3], dim=1) + self.w2 = ( + _load_experts(config, f"{prefix}.experts", "w2", weights) + .view(self.num_experts, self.ffn_dim, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) + + self.process_group = weights.process_group + + def forward(self, x, adapter_data) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(x) + out = fused_moe( + x, + self.w13, + self.w2, + router_logits, + self.top_k, + renormalize=True, + inplace=True, + ) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + class LlamaMLP(nn.Module): def __init__(self, prefix, config, weights, index): super().__init__() @@ -353,9 +456,14 @@ class FlashLlamaLayer(nn.Module): weights=weights, ) - self.mlp = LlamaMLP( - prefix=f"{prefix}.mlp", config=config, weights=weights, index=index - ) + self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct" + + if self.use_moe: + self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights) + else: + self.dense = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -401,7 +509,7 @@ class FlashLlamaLayer(nn.Module): 
             attn_output, res
         )
 
-        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+        mlp_output = self.dense(normed_attn_res_output, adapter_data)
 
         return mlp_output, attn_res
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 047eca42..2a0dc606 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -27,15 +27,6 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 
-from text_generation_server.layers import (
-    FastLinear,
-)
-
-from text_generation_server.utils.import_utils import SYSTEM
-
-if SYSTEM != "ipex":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-
 
 class PhiConfig(PretrainedConfig):
     def __init__(
@@ -78,16 +69,7 @@ class PhiConfig(PretrainedConfig):
 
 # this is the same as llama except for Phi uses bias=True
 def load_attention(config, prefix, weights):
-    if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
-        return TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=True,
-        )
-    if config.num_attention_heads != config.num_key_value_heads \
-        and False:
+    if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
         return TensorParallelColumnLinear.load_multi(
@@ -108,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize and (config.quantize not in ["gptq", "awq", "marlin"]):
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads
@@ -123,103 +105,6 @@ def _load_gqa(config, prefix: str, weights):
     return TensorParallelColumnLinear(get_linear(weight, bias=True))
 
 
-def _load_experts(config, prefix: str, mat, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Mixtral does not support weight quantization yet.")
-
-    assert mat in ["w1", "w2", "w3"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.intermediate_size % world_size == 0
-    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.num_local_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.num_local_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "w2":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
-
-        # gating
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
-        )
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out.view(*x.shape)
-
-
 class FlashPhiAttention(torch.nn.Module):
     def __init__(
         self,
@@ -233,12 +118,7 @@ class FlashPhiAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
         self.softmax_scale = self.head_size**-0.5
 
-        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
-        self.head_dim = self.head_size
-        if hasattr(config, "partial_rotary_factor"):
-            self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
-        else:
-            self.rotary_dim = self.head_size
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -261,10 +141,9 @@ class FlashPhiAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
 
         # in llama the dense layer is called "o_proj" and has bias=False
-        proj_layer_name = f"{prefix}.o_proj" if self.use_moe else f"{prefix}.dense"
         self.dense = TensorParallelRowLinear.load(
             config,
-            prefix=proj_layer_name,
+            prefix=f"{prefix}.dense",
             weights=weights,
             bias=True,
         )
@@ -304,14 +183,9 @@ class FlashPhiAttention(torch.nn.Module):
         # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
         #
         # Apply partial positional embeddings in place
-        if self.use_moe:
-            # rotate half rotary_dim
-            # half_size = torch.select(kv, dim=1, index=0).size(1) // 2
-            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        else:
-            self.rotary_emb(
-                query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
-            )
+        self.rotary_emb(
+            query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
+        )
 
         # Reshape key and value and cache
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
@@ -384,30 +258,13 @@ class FlashPhiLayer(nn.Module):
         self.self_attn = FlashPhiAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-
-        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
-
-        if self.use_moe:
-            self.moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
-        else:
-            self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
-
+        self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
         self.input_layernorm = FastLayerNorm.load(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
-            # eps=config.layer_norm_eps,
-            eps=1e-5,
+            eps=config.layer_norm_eps,
         )
-
-        if self.use_moe:
-            self.post_attn_layernorm = FastLayerNorm.load(
-                prefix=f"{prefix}.post_attention_layernorm",
-                weights=weights,
-                eps=1e-5,
-            )
-
-        # self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
-        self.resid_dropout = torch.nn.Dropout(config.attention_dropout)
+        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
 
     def forward(
         self,
@@ -436,28 +293,9 @@ class FlashPhiLayer(nn.Module):
             max_s,
         )
 
-
-        if self.use_moe:
-            # hidden_states = attn_output
-            # if residual is not None:
-            #     hidden_states = hidden_states + residual
-            # residual = hidden_states
-            # hidden_states, router_states = self.post_attn_layernorm(
-            #     hidden_states
-            # )
-            # hidden_states = self.moe(hidden_states)
-            # # hidden_states = hidden_states + residual
-            # res = residual
-
-            hidden_states = self.resid_dropout(attn_output)
-            _hidden_states = self.resid_dropout(self.moe(hidden_states))
-            hidden_states = hidden_states.add(_hidden_states)
-
-
-        else:
-            hidden_states = self.resid_dropout(attn_output).add(
-                self.resid_dropout(self.mlp(hidden_states))
-            )
+        hidden_states = self.resid_dropout(attn_output).add(
+            self.resid_dropout(self.mlp(hidden_states))
+        )
 
         return hidden_states, res
 
@@ -489,20 +327,11 @@ class FlashPhiModel(torch.nn.Module):
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
-        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
-
-        if self.use_moe:
-            self.norm = FastLayerNorm.load(
-                prefix="model.norm",
-                weights=weights,
-                eps=1e-5,
-            )
-        else:
-            self.norm = FastLayerNorm.load(
-                prefix="model.final_layernorm",
-                weights=weights,
-                eps=config.layer_norm_eps,
-            )
+        self.norm = FastLayerNorm.load(
+            prefix="model.final_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
 
     def forward(
         self,