diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index f0270ff8..76e92485 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -302,8 +302,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 for p in prefixes
             ]
             scale = torch.cat(scale, dim=dim)
-            if scale.device == torch.device("cpu"):
-                scale = scale.to(weights.device)
+            scale = scale.to(weights.device)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -358,6 +357,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
+                # XXX: the weight is named scale_inv, but it seems to correspond to the scale itself.
                 scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)

                 return Fp8Weight(
@@ -504,6 +504,10 @@ class Fp8Linear(torch.nn.Module):

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.weight_block_size is not None:
+            # At a more granular level, as illustrated in Figure 7 (a): (1) for activations, we group and
+            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
+            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
+            # channels).
             qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
             output = w8a8_block_fp8_matmul(
                 qinput,
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index fb285bfe..23d0d38c 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -52,6 +52,8 @@ class MoELayer(Protocol):
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
         hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ): ...

     def forward(
@@ -81,9 +83,14 @@ class DenseMoELayer(nn.Module):
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
         hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ):
         super().__init__()

+        assert scoring_func is None, "scoring_func is not handled"
+        assert e_score_correction_bias is None, "e_score_correction_bias is not handled"
+
         log_once(
             logger.info,
             "No fused layers are available for this model type, using (slower) dense MoE layer",
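
For context on the block-quantized path in `Fp8Linear.forward`: `per_token_group_quant_fp8` produces one scale per 1x128 activation tile (per token, per 128 channels), while the weight carries one scale per 128x128 block, and `w8a8_block_fp8_matmul` combines the two. Below is a minimal reference sketch of the activation side only, assuming symmetric e4m3 quantization; the helper name `per_token_group_quant_fp8_ref` and the epsilon clamp are illustrative, not the actual kernel implementation:

```python
import torch

FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int):
    """Quantize each contiguous `group_size` slice of the last dim to FP8,
    with one scale per (token, group) tile."""
    assert x.shape[-1] % group_size == 0
    grouped = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # One scale per 1 x group_size tile: amax of the tile divided by the FP8 max.
    scale = grouped.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX
    scale = scale.clamp(min=1e-12)  # illustrative guard against all-zero tiles
    q = (grouped / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape_as(x), scale.squeeze(-1)
```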
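
The new `scoring_func` / `e_score_correction_bias` arguments cover DeepSeek-V3-style routing, where expert affinities come from a sigmoid rather than a softmax, and a per-expert correction bias is added before top-k selection but not to the combine weights. `DenseMoELayer` only asserts they are unset; the following is a hedged sketch of what a layer that does handle them might compute, with the helper name `select_experts` and the tensor-shaped bias being assumptions (the diff itself types the bias as `Optional[float]`):

```python
from typing import Optional, Tuple

import torch

def select_experts(
    router_logits: torch.Tensor,  # [tokens, n_experts]
    top_k: int,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,  # assumed [n_experts]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # DeepSeek-V3 scores experts with a sigmoid instead of the usual softmax.
    if scoring_func == "sigmoid":
        scores = router_logits.sigmoid()
    else:
        scores = router_logits.softmax(dim=-1)
    # The correction bias only influences *which* experts are selected...
    if e_score_correction_bias is not None:
        selection_scores = scores + e_score_correction_bias
    else:
        selection_scores = scores
    _, topk_ids = selection_scores.topk(top_k, dim=-1)
    # ...while the combine weights are gathered from the unbiased scores.
    topk_weights = scores.gather(-1, topk_ids)
    return topk_weights, topk_ids
```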