mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fixing Mixtral + Nits.
This commit is contained in:
parent
f56e24b346
commit
51bc8a4e45
@ -302,7 +302,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
for p in prefixes
|
||||
]
|
||||
scale = torch.cat(scale, dim=dim)
|
||||
if scale.device == torch.device("cpu"):
|
||||
scale = scale.to(weights.device)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
@ -358,6 +357,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_block_size is not None:
|
||||
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
|
||||
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
|
||||
|
||||
return Fp8Weight(
|
||||
@ -504,6 +504,10 @@ class Fp8Linear(torch.nn.Module):
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.weight_block_size is not None:
|
||||
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
|
||||
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
|
||||
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
|
||||
# channels).
|
||||
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
|
||||
output = w8a8_block_fp8_matmul(
|
||||
qinput,
|
||||
|
@ -52,6 +52,8 @@ class MoELayer(Protocol):
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
hidden_act: str = "silu",
|
||||
scoring_func: Optional[str] = None,
|
||||
e_score_correction_bias: Optional[float] = None,
|
||||
): ...
|
||||
|
||||
def forward(
|
||||
@ -81,9 +83,14 @@ class DenseMoELayer(nn.Module):
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
hidden_act: str = "silu",
|
||||
scoring_func: Optional[str] = None,
|
||||
e_score_correction_bias: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert scoring_func is None, "scoring func is not handled"
|
||||
assert e_score_correction_bias is None, "scoring correction bias is not handled"
|
||||
|
||||
log_once(
|
||||
logger.info,
|
||||
"No fused layers are available for this model type, using (slower) dense MoE layer",
|
||||
|
Loading…
Reference in New Issue
Block a user