Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
format code
commit 05d0aa678e
parent 57de05b0dd
server/text_generation_server/layers/fp8.py
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
 import os
 from typing import Optional, Tuple, Type, Union, List
-from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 
 import torch
 from loguru import logger
 
+from moe_kernels import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import (
     Weight,
@@ -187,8 +187,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
                 scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
-                if scale.device == torch.device("cpu"):
-                    scale = scale.to(weights.device)
                 return Fp8Weight(
                     weight=w,
                     weight_scale=scale,
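Background on the `weight_scale_inv` tensors handled above: in block-quantized FP8 checkpoints, the file stores one dequantization factor per weight block rather than per tensor. A minimal sketch of the shape relationship, assuming an illustrative 128×128 block size (the real block size comes from `weight_block_size`; shapes here are made up):

```python
import torch

# A (512, 1024) FP8 weight quantized in 128x128 blocks carries a
# (4, 8) scale tensor: one entry per block.
block_m, block_n = 128, 128
w = torch.zeros(512, 1024, dtype=torch.float8_e4m3fn)
scale_inv = torch.rand(512 // block_m, 1024 // block_n)

# Dequantize by broadcasting each scale over its block.
w_dq = w.to(torch.float32) * scale_inv.repeat_interleave(
    block_m, dim=0
).repeat_interleave(block_n, dim=1)
assert w_dq.shape == w.shape
```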
@@ -289,9 +287,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
                     for p in prefixes
                 ]
-                scale = torch.cat(scale, dim=dim)
-                if scale.device == torch.device("cpu"):
-                    scale = scale.to(weights.device)
+                scale = torch.cat(scale, dim=dim).to(weights.device)
                 return Fp8Weight(
                     weight=w,
                     weight_scale=scale,
@@ -347,8 +343,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
                 scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
-                if scale.device == torch.device("cpu"):
-                    scale = scale.to(weights.device)
 
                 return Fp8Weight(
                     weight=w,
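The three hunks above drop the same `if scale.device == torch.device("cpu")` guard. Where the scale comes from `get_tensor`/`get_sharded` with their default device placement it is already on the target device, so the guard was effectively dead code; where the shards are deliberately fetched with `to_device=False`, the move is folded into the concatenation. The unconditional form is safe because `Tensor.to` returns the input tensor unchanged when it is already on the requested device; a quick sketch:

```python
import torch

x = torch.zeros(4)  # lives on the CPU
# .to() returns the very same object when the tensor already has the
# target device and dtype, so no copy or transfer happens.
assert x.to(torch.device("cpu")) is x

# Hence the new one-liner: concatenate shards, then move unconditionally.
shards = [torch.ones(2, 2), torch.ones(2, 2)]
scale = torch.cat(shards, dim=0).to(torch.device("cpu"))
```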
server/text_generation_server/utils/quantization.py
@@ -60,6 +60,7 @@ def _get_quantizer_config(model_id, revision):
             return _FP8QuantizerConfig(
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)
 
         if "zero_point" in data["quantization_config"]:
             sym = not data["quantization_config"]["zero_point"]
@@ -67,16 +68,12 @@ def _get_quantizer_config(model_id, revision):
         elif "sym" in data["quantization_config"]:
             sym = data["quantization_config"]["sym"]
 
-        if "bits" in data["quantization_config"]:
-            bits = data["quantization_config"]["bits"]
-        if "group_size" in data["quantization_config"]:
-            groupsize = data["quantization_config"]["group_size"]
+        bits = data["quantization_config"]["bits"]
+        groupsize = data["quantization_config"]["group_size"]
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
-        checkpoint_format = data["quantization_config"].get("checkpoint_format", None)
-        if desc_act in data["quantization_config"]:
-            desc_act = data["quantization_config"]["desc_act"]
-        weight_block_size = data["quantization_config"].get("weight_block_size", None)
+        checkpoint_format = data["quantization_config"].get("checkpoint_format")
+        desc_act = data["quantization_config"]["desc_act"]
     except Exception:
         filename = "quantize_config.json"
         try:
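For context, a minimal sketch of the parsing pattern these two hunks converge on: required keys use direct indexing, so a missing key raises `KeyError` and falls through to the `except` branch that retries `quantize_config.json`, while optional keys use `.get()`. The `_QuantizerConfig` dataclass and exact field set below are assumptions for illustration, not copied from the repository:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _QuantizerConfig:
    bits: int
    groupsize: int
    quant_method: str
    checkpoint_format: Optional[str]
    desc_act: bool
    sym: bool
    weight_block_size: Optional[List[int]]


def parse_quantization_config(data: dict) -> _QuantizerConfig:
    qc = data["quantization_config"]
    # AWQ-style configs store "zero_point"; GPTQ-style configs store
    # "sym". The two express the same property with opposite polarity.
    if "zero_point" in qc:
        sym = not qc["zero_point"]
    else:
        sym = qc.get("sym", False)
    return _QuantizerConfig(
        # Required keys: direct indexing raises KeyError on absence,
        # which the caller catches to fall back to quantize_config.json.
        bits=qc["bits"],
        groupsize=qc["group_size"],
        quant_method=qc["quant_method"],
        desc_act=qc["desc_act"],
        sym=sym,
        # Optional keys: .get() returns None instead of raising.
        checkpoint_format=qc.get("checkpoint_format"),
        weight_block_size=qc.get("weight_block_size"),
    )
```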
server/text_generation_server/utils/weights.py
@@ -149,10 +149,6 @@ class Weights:
     ):
         routing = {}
         for filename in filenames:
-            # if filename.as_posix().endswith("l.safetensors"):
-            #     from loguru import logger
-            #     logger.info(f"Skipping {filename}")
-            #     continue
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
                     if k in routing:
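The surviving loop builds a routing table from tensor name to the shard file that holds it, erroring on duplicates. A self-contained sketch of the pattern (the function name and error text are illustrative, not the repository's exact code):

```python
from typing import Dict, List

from safetensors import safe_open


def build_routing(filenames: List[str]) -> Dict[str, str]:
    routing: Dict[str, str] = {}
    for filename in filenames:
        # safe_open only parses the file header, so scanning the keys
        # of every shard is cheap even for multi-gigabyte files.
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: "
                        f"{filename} and {routing[k]}"
                    )
                routing[k] = filename
    return routing
```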