format code

Mohit Sharma 2025-01-29 10:39:09 +00:00
parent 57de05b0dd
commit 05d0aa678e
3 changed files with 7 additions and 20 deletions

View File

@@ -1,11 +1,11 @@
 from dataclasses import dataclass
 import os
 from typing import Optional, Tuple, Type, Union, List
-from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 import torch
 from loguru import logger
+from moe_kernels import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import (
     Weight,
@@ -187,8 +187,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
                 scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
-                if scale.device == torch.device("cpu"):
-                    scale = scale.to(weights.device)
                 return Fp8Weight(
                     weight=w,
                     weight_scale=scale,
@@ -289,9 +287,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
                 for p in prefixes
             ]
-            scale = torch.cat(scale, dim=dim)
-            if scale.device == torch.device("cpu"):
-                scale = scale.to(weights.device)
+            scale = torch.cat(scale, dim=dim).to(weights.device)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -347,8 +343,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
                 scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
-                if scale.device == torch.device("cpu"):
-                    scale = scale.to(weights.device)
                 return Fp8Weight(
                     weight=w,
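Note on the removed device checks above: assuming `weights.get_tensor` / `weights.get_sharded` already place tensors on the target device unless `to_device=False` is passed, the explicit CPU guards were redundant, and in the one path that does pass `to_device=False` the move is now folded into the `torch.cat` call. A minimal standalone sketch of why chaining `.to(...)` is equivalent (tensors and device names here are illustrative, not from the repository):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# e.g. per-shard block scales loaded with to_device=False (still on CPU)
parts = [torch.randn(4, 4) for _ in range(2)]

# Tensor.to() is a no-op when the tensor is already on the target device,
# so no explicit `if scale.device == torch.device("cpu")` guard is needed.
scale = torch.cat(parts, dim=0).to(device)
assert scale.device.type == device.type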

View File

@@ -60,6 +60,7 @@ def _get_quantizer_config(model_id, revision):
             return _FP8QuantizerConfig(
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)
 
         if "zero_point" in data["quantization_config"]:
             sym = not data["quantization_config"]["zero_point"]
@@ -67,16 +68,12 @@ def _get_quantizer_config(model_id, revision):
         elif "sym" in data["quantization_config"]:
             sym = data["quantization_config"]["sym"]
 
-        if "bits" in data["quantization_config"]:
-            bits = data["quantization_config"]["bits"]
-        if "group_size" in data["quantization_config"]:
-            groupsize = data["quantization_config"]["group_size"]
+        bits = data["quantization_config"]["bits"]
+        groupsize = data["quantization_config"]["group_size"]
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
-        checkpoint_format = data["quantization_config"].get("checkpoint_format", None)
-        if desc_act in data["quantization_config"]:
-            desc_act = data["quantization_config"]["desc_act"]
-        weight_block_size = data["quantization_config"].get("weight_block_size", None)
+        checkpoint_format = data["quantization_config"].get("checkpoint_format")
+        desc_act = data["quantization_config"]["desc_act"]
     except Exception:
         filename = "quantize_config.json"
         try:
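For reference, a standalone sketch of the lookups the reworked `_get_quantizer_config` body now performs. The example `quantization_config` dict below is made up purely for illustration (the keys mirror those referenced in the diff; the values are not from any real model), and `.get(key)` is equivalent to the dropped `.get(key, None)`:

data = {
    "quantization_config": {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 128,
        "sym": True,
        "desc_act": False,
        # "checkpoint_format" and "weight_block_size" are optional, hence .get()
    }
}

cfg = data["quantization_config"]
weight_block_size = cfg.get("weight_block_size", None)  # None when the key is absent
bits = cfg["bits"]               # now assumed present; raises KeyError if missing
groupsize = cfg["group_size"]
quant_method = cfg["quant_method"]
checkpoint_format = cfg.get("checkpoint_format")  # .get() already defaults to None
desc_act = cfg["desc_act"]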

View File

@@ -149,10 +149,6 @@ class Weights:
     ):
         routing = {}
         for filename in filenames:
-            # if filename.as_posix().endswith("l.safetensors"):
-            #     from loguru import logger
-            #     logger.info(f"Skipping {filename}")
-            #     continue
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
                     if k in routing:
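The four removed lines were leftover commented-out debugging code; the surrounding loop builds a routing table mapping each tensor key to the shard file that holds it. A simplified sketch of that pattern (function name and error message are illustrative, not copied from the repository):

from safetensors import safe_open

def build_routing(filenames):
    routing = {}
    for filename in filenames:
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing:
                    raise RuntimeError(
                        f"Key {k} appears in more than one file: {filename} and {routing[k]}"
                    )
                routing[k] = filename
    return routing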