diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py
index 78f03511..44d30202 100644
--- a/backends/gaudi/server/text_generation_server/layers/fp8.py
+++ b/backends/gaudi/server/text_generation_server/layers/fp8.py
@@ -409,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         return UnquantizedWeight(w)
 
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
+        shapes = [x.shape for x in w]
+
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)
+
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = [
+                    weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
+                    for p in prefixes
+                ]
+                scale = torch.cat(scale, dim=dim)
+                scale = scale.to(weights.device)
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
+            scale = [
+                weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
+                for p in prefixes
+            ]
+            scale = torch.cat(scale, dim=0).reshape(-1)
+
+            input_scale = [
+                weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
+                for p in prefixes
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
+
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
+
+        return UnquantizedWeight(w)
+
     def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
index 5362e8de..5365f24f 100644
--- a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
+++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import os
 
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.fp8 import (
@@ -46,6 +47,16 @@ class FP8SparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
+        self.ep_rank = self.rank
+        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+
+        if self.use_ep:
+            n_experts = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.ep_rank * n_experts
+        else:
+            self.ep_offset = 0
 
         (
             self.gate_up_proj,
@@ -57,6 +68,8 @@ class FP8SparseMoELayer(nn.Module):
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )
 
         self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
@@ -65,6 +78,8 @@ class FP8SparseMoELayer(nn.Module):
                 n_experts=n_experts,
                 name=down_proj_name,
                 weights=weights,
+                use_ep=self.use_ep,
+                ep_offset=self.ep_offset,
             )
         )
         if self.weight_block_size is not None:
@@ -99,8 +114,15 @@ class FP8SparseMoELayer(nn.Module):
         )
         total_num_experts = gating_output.size(-1)
         x_fp8, x_scale = dynamic_quant(x, single_scale=True)
-        moe_n_slice = (total_num_experts + 31) // 32
-        n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
+
+        if self.use_ep:
+            moe_n_slice = 1
+            n_expert_slice = (
+                total_num_experts + self.world_size - 1
+            ) // self.world_size
+        else:
+            moe_n_slice = 1
+            n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
         for i in range(moe_n_slice):
             min_expert = i * n_expert_slice
             max_expert = min((i + 1) * n_expert_slice, total_num_experts)
@@ -130,8 +152,8 @@ class FP8SparseMoELayer(nn.Module):
                 d_scale_w3=w2_weight_scale,
                 permuted_weights=True,
                 activation="silu",
-                experts_min=min_expert,
-                experts_max=max_expert - 1,
+                experts_min=min_expert + self.ep_offset,
+                experts_max=max_expert + self.ep_offset - 1,
             )
             htorch.core.mark_step()
             if i == 0:
@@ -148,13 +170,14 @@ def _load_expert_weights(
     n_experts: int,
     name: str,
     weights: Weights,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     all_weight_scales = None
     max_input_scale = None
 
     for i in range(n_experts):
-        weight = get_weight_fn(prefix, i, name, weights)
+        weight = get_weight_fn(prefix, i + ep_offset, name, weights)
 
         assert isinstance(weight, Fp8Weight)
 
@@ -197,14 +220,26 @@ def _load_expert_multi_weights_col(
     gate_proj_name: str,
     up_proj_name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
-    def get_weight_fn(prefix, i, name, weights):
+    def get_weight_fn_sharded(prefix, i, name, weights):
         return weights.get_multi_weights_col(
             [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
         )
 
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_multi_weights(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
     return _load_expert_weights(
-        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
+        get_weight_fn if use_ep else get_weight_fn_sharded,
+        prefix=prefix,
+        n_experts=n_experts,
+        name=None,
+        weights=weights,
+        ep_offset=ep_offset if use_ep else 0,
     )
 
 
@@ -214,10 +249,20 @@ def _load_expert_weights_row(
     n_experts: int,
     name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
-    def get_weight_fn(prefix, i, name, weights):
+    def get_weight_fn_sharded(prefix, i, name, weights):
         return weights.get_weights_row(f"{prefix}.{i}.{name}")
 
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_weights(f"{prefix}.{i}.{name}")
+
     return _load_expert_weights(
-        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
+        get_weight_fn if use_ep else get_weight_fn_sharded,
+        prefix=prefix,
+        n_experts=n_experts,
+        name=name,
+        weights=weights,
+        ep_offset=ep_offset if use_ep else 0,
     )
diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py
index acd598d7..da936d36 100644
--- a/backends/gaudi/server/text_generation_server/utils/weights.py
+++ b/backends/gaudi/server/text_generation_server/utils/weights.py
@@ -62,6 +62,14 @@ class WeightsLoader(ABC):
         """
         ...
 
+    @abstractmethod
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        """
+        Get the unsharded weights at the given prefixes and concatenate them
+        along the given dimension.
+        """
+        ...
+
     @abstractmethod
     def get_weights_row(self, weights: "Weights", prefix: str):
         """
@@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
             weights.get_sharded(f"{prefix}.weight", dim=1),
         )
 
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
+        return self.weight_class(torch.cat(w, dim=dim))
+
 
 class Weights:
     def __init__(
@@ -390,6 +402,9 @@ class Weights:
     def get_weights_row(self, prefix: str):
         return self.weights_loader.get_weights_row(self, prefix)
 
+    def get_multi_weights(self, prefixes: List[str], dim: int):
+        return self.weights_loader.get_multi_weights(self, prefixes, dim)
+
     @contextmanager
     def use_loader(self, weights_loader: WeightsLoader):
         """
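Reviewer note: the expert-parallel path above splits the expert range across ranks with a ceil-divide and then shifts the expert indices passed to the fused MoE op by `ep_offset`. Below is a minimal, self-contained sketch of that partitioning math; the expert count, rank, and world size are hypothetical example values (in the layer they come from the model config and `weights.process_group`), and `partition_experts` is an illustrative helper, not a function in this PR.

    import os


    def partition_experts(n_experts: int, rank: int, world_size: int):
        """Mirror the EP setup in FP8SparseMoELayer.__init__: each rank owns a
        contiguous slice of experts and records the offset of its first expert."""
        use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
        if use_ep:
            # Ceil-divide so every expert is covered even when the count does
            # not split evenly across ranks.
            local_experts = (n_experts + world_size - 1) // world_size
            ep_offset = rank * local_experts
        else:
            local_experts = n_experts
            ep_offset = 0
        return local_experts, ep_offset


    # Hypothetical example: 64 experts over 4 ranks -> 16 experts per rank.
    # Rank 2 loads experts 32..47 and, per the diff, calls the fused MoE op
    # with experts_min = 0 + ep_offset = 32 and experts_max = 16 + ep_offset - 1 = 47.
    print(partition_experts(64, rank=2, world_size=4))  # (16, 32)

This is also why the WeightsLoader interface gains an unsharded `get_multi_weights` accessor: when USE_EXPERT_PARALLEL is enabled each rank loads the full weights of its own experts (via `get_multi_weights` / `get_weights`) instead of tensor-parallel shards (`get_multi_weights_col` / `get_weights_row`).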