mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
format code
This commit is contained in:
parent
57de05b0dd
commit
05d0aa678e
@ -1,11 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Optional, Tuple, Type, Union, List
|
||||
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from moe_kernels import w8a8_block_fp8_matmul, per_token_group_quant_fp8
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import (
|
||||
Weight,
|
||||
@ -187,8 +187,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_block_size is not None:
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
|
||||
if scale.device == torch.device("cpu"):
|
||||
scale = scale.to(weights.device)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
@ -289,9 +287,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
|
||||
for p in prefixes
|
||||
]
|
||||
scale = torch.cat(scale, dim=dim)
|
||||
if scale.device == torch.device("cpu"):
|
||||
scale = scale.to(weights.device)
|
||||
scale = torch.cat(scale, dim=dim).to(weights.device)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
@ -347,8 +343,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_block_size is not None:
|
||||
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
|
||||
if scale.device == torch.device("cpu"):
|
||||
scale = scale.to(weights.device)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
|
@ -60,6 +60,7 @@ def _get_quantizer_config(model_id, revision):
|
||||
return _FP8QuantizerConfig(
|
||||
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
|
||||
)
|
||||
weight_block_size = data["quantization_config"].get("weight_block_size", None)
|
||||
|
||||
if "zero_point" in data["quantization_config"]:
|
||||
sym = not data["quantization_config"]["zero_point"]
|
||||
@ -67,16 +68,12 @@ def _get_quantizer_config(model_id, revision):
|
||||
elif "sym" in data["quantization_config"]:
|
||||
sym = data["quantization_config"]["sym"]
|
||||
|
||||
if "bits" in data["quantization_config"]:
|
||||
bits = data["quantization_config"]["bits"]
|
||||
if "group_size" in data["quantization_config"]:
|
||||
groupsize = data["quantization_config"]["group_size"]
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
quant_method = data["quantization_config"]["quant_method"]
|
||||
checkpoint_format = data["quantization_config"].get("checkpoint_format", None)
|
||||
if desc_act in data["quantization_config"]:
|
||||
checkpoint_format = data["quantization_config"].get("checkpoint_format")
|
||||
desc_act = data["quantization_config"]["desc_act"]
|
||||
weight_block_size = data["quantization_config"].get("weight_block_size", None)
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
|
@ -149,10 +149,6 @@ class Weights:
|
||||
):
|
||||
routing = {}
|
||||
for filename in filenames:
|
||||
# if filename.as_posix().endswith("l.safetensors"):
|
||||
# from loguru import logger
|
||||
# logger.info(f"Skipping {filename}")
|
||||
# continue
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
|
Loading…
Reference in New Issue
Block a user