Fixing Mixtral + Nits.

Nicolas Patry 2025-01-30 16:09:15 +01:00
parent f56e24b346
commit 51bc8a4e45
2 changed files with 13 additions and 2 deletions


@@ -302,7 +302,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 for p in prefixes
             ]
             scale = torch.cat(scale, dim=dim)
-            if scale.device == torch.device("cpu"):
-                scale = scale.to(weights.device)
+            scale = scale.to(weights.device)
             return Fp8Weight(
                 weight=w,
@@ -358,6 +357,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
+                # XXX: Yes, the weight is named scale_inv, but it seems to correspond to the scale.
                 scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
                 return Fp8Weight(
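The XXX comment added above flags a naming quirk: the checkpoint tensor is called weight_scale_inv, yet it seems to be consumed as a plain multiplicative scale. As a minimal sketch of that assumption (not code from this repository; the helper name and the 128x128 block layout are illustrative), block-wise dequantization would look like:

import torch

def dequantize_blockwise_fp8(w_fp8: torch.Tensor, scale: torch.Tensor, block: int = 128) -> torch.Tensor:
    # w_fp8: (out_features, in_features) stored as float8_e4m3fn.
    # scale: (out_features // block, in_features // block); despite the "scale_inv"
    # name in the checkpoint, it is applied here as a multiplier, not divided by.
    out_f, in_f = w_fp8.shape
    w = w_fp8.to(torch.float32).reshape(out_f // block, block, in_f // block, block)
    w = w * scale.to(torch.float32).reshape(out_f // block, 1, in_f // block, 1)
    return w.reshape(out_f, in_f)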
@@ -504,6 +504,10 @@ class Fp8Linear(torch.nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.weight_block_size is not None:
+            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
+            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
+            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
+            # channels).
             qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
             output = w8a8_block_fp8_matmul(
                 qinput,
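The quoted comment (apparently taken from the DeepSeek-V3 report) describes the quantization granularity: activations get one scale per 1x128 tile (per token per 128 channels), weights one scale per 128x128 block. A minimal reference sketch of the activation side, assuming this layout and not the actual per_token_group_quant_fp8 kernel, could look like:

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    # x: (num_tokens, hidden_size), with hidden_size divisible by group_size.
    finfo = torch.finfo(torch.float8_e4m3fn)
    groups = x.reshape(x.shape[0], -1, group_size)
    # One scale per (token, 128-channel group), sized so the group max reaches the FP8 range.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / finfo.max
    q = (groups / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape_as(x), scale.squeeze(-1)

w8a8_block_fp8_matmul then presumably combines these per-tile activation scales with the per-128x128-block weight scales while accumulating the FP8 products.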


@@ -52,6 +52,8 @@ class MoELayer(Protocol):
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
         hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ): ...

     def forward(
@@ -81,9 +83,14 @@ class DenseMoELayer(nn.Module):
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
         hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ):
         super().__init__()

+        assert scoring_func is None, "scoring func is not handled"
+        assert e_score_correction_bias is None, "scoring correction bias is not handled"
+
         log_once(
             logger.info,
             "No fused layers are available for this model type, using (slower) dense MoE layer",