Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

commit 6a93a24f3f (parent 081d16cab5)

    refactored weights loader
@@ -8,6 +8,7 @@ from typing import Optional
 from enum import Enum
 from huggingface_hub import hf_hub_download
 
+from text_generation_server.utils.log import log_master
 
 app = typer.Typer()
 
@@ -87,15 +88,17 @@ def serve(
     )
 
     if len(lora_adapter_ids) > 0:
-        logger.warning(
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        log_master(
+            logger.warning,
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
         )
 
     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
     if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
-        logger.warning(
-            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
+        log_master(
+            logger.warning,
+            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
         )
         global CUDA_GRAPHS
         CUDA_GRAPHS = None
@@ -3,6 +3,7 @@ import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers.attention import Seqlen
+from text_generation_server.utils.log import log_master
 from loguru import logger
 
 major, minor = torch.cuda.get_device_capability()
@@ -136,7 +137,10 @@ if ENGINE != "triton":
     try:
         import flash_attn_2_cuda
 
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+        log_master(
+            logger.info,
+            "ROCm: using Flash Attention 2 Composable Kernel implementation.",
+        )
     except ImportError as e:
         if major >= 8:
             architecture_suffix = f"-{SYSTEM}"
@@ -4,19 +4,11 @@ from functools import lru_cache
 import bitsandbytes as bnb
 import torch
 from bitsandbytes.nn import Int8Params, Params4bit
-from loguru import logger
-from text_generation_server.utils.weights import Weight
-
-
-@lru_cache(1)
-def warn_deprecate_bnb():
-    logger.warning(
-        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
-    )
+from text_generation_server.utils.weights import UnquantizedWeight
 
 
 @dataclass
-class BNBWeight(Weight):
+class BNBWeight(UnquantizedWeight):
     weight: torch.Tensor
 
     def get_linear(self, bias: torch.Tensor):
@@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module):
 
 
 @dataclass
-class BNBFP4Weight(Weight):
+class BNBFP4Weight(UnquantizedWeight):
     weight: torch.Tensor
 
     def get_linear(self, bias: torch.Tensor):
@@ -90,7 +82,7 @@ class BNBFP4Weight(Weight):
 
 
 @dataclass
-class BNBNF4Weight(Weight):
+class BNBNF4Weight(UnquantizedWeight):
     weight: torch.Tensor
 
     def get_linear(self, bias: torch.Tensor):
@@ -2,11 +2,11 @@ from dataclasses import dataclass
 
 import torch
 from EETQ import quant_weights, w8_a16_gemm
-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import UnquantizedWeight
 
 
 @dataclass
-class EETQWeight(Weight):
+class EETQWeight(UnquantizedWeight):
     weight: torch.Tensor
 
     def get_linear(self, bias: torch.Tensor):
@@ -1,19 +1,29 @@
 import torch
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union, List
+from loguru import logger
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import (
+    Weight,
+    WeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)
+from text_generation_server.utils.log import log_master, log_once
 
+FBGEMM_MM_AVAILABLE = False
+FBGEMM_DYN_AVAILABLE = False
 try:
     import fbgemm_gpu.experimental.gen_ai
 
-    major, _ = torch.cuda.get_device_capability()
-    HAS_FBGEMM_MM = major == 9
-    HAS_FBGEMM_DYN = major >= 8
+    if SYSTEM == "cuda":
+        major, _ = torch.cuda.get_device_capability()
+        FBGEMM_MM_AVAILABLE = major == 9
+        FBGEMM_DYN_AVAILABLE = major >= 8
 except (ImportError, ModuleNotFoundError):
-    HAS_FBGEMM_MM = False
-    HAS_FBGEMM_DYN = False
+    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
 
 
 def get_fp8_linear() -> torch.nn.Module:
@@ -33,7 +43,7 @@ def get_fp8_linear() -> torch.nn.Module:
 
 
 def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if HAS_FBGEMM_DYN:
+    if FBGEMM_DYN_AVAILABLE:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
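The fallback path of `fp8_quantize` (not shown in this hunk) computes the scale itself when the FBGEMM kernels are missing. Below is a rough, hedged sketch of what per-row e4m3 quantization does; the exact semantics of `torch.ops.fbgemm.quantize_fp8_per_row` (for example, how `scale_ub` is applied) are assumptions here, not taken from the kernel.

```python
import torch
from typing import Optional


def per_row_fp8_quantize(weight: torch.Tensor, scale_upper_bound: Optional[float] = None):
    # One scale per output row; 448.0 is the largest finite e4m3 value.
    finfo = torch.finfo(torch.float8_e4m3fn)
    row_max = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    if scale_upper_bound is not None:
        # Assumption: the upper bound caps the per-row dynamic range.
        row_max = row_max.clamp(max=scale_upper_bound)
    scale = row_max / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale.reshape(-1)


qw, s = per_row_fp8_quantize(torch.randn(4, 16))
print(qw.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([4])
```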
@@ -54,18 +64,83 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
     return qweight, scale
 
 
+class HybridFP8UnquantLoader(WeightsLoader):
+    """Weight loader that loads FP8 and unquantized Torch tensors."""
+
+    def __init__(self, activation_scale_ub: Optional[float]):
+        self.activation_scale_ub = activation_scale_ub
+
+    def get_weights_col_packed(
+        self,
+        weights: Weights,
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+
+        if w.dtype == torch.float8_e4m3fn:
+            # FP8 branch
+            scale = weights.get_packed_sharded(
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
+            )
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+
+        return UnquantizedWeight(w)
+
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        w = torch.cat(w, dim=dim)
+
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = [
+                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
+                for p in prefixes
+            ]
+            scale = torch.cat(scale, dim=0)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+
+        return UnquantizedWeight(w)
+
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+
+        return UnquantizedWeight(w)
+
+
 @dataclass
-class Fp8Weight:
+class Fp8Weight(Weight):
     weight: torch.Tensor
     dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
-    input_scale: Optional[torch.Tensor] = None
+    activation_scale_ub: Optional[float] = None
 
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.input_scale, bias, self.dtype
+            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
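A minimal sketch of the dispatch rule the new loader applies (illustrative only; the fake checkpoint dict and helper name below are not part of the repository): the dtype stored in the safetensors file decides whether an FP8 weight with its per-row scale or a plain unquantized tensor is produced.

```python
import torch


def classify_weight(tensors: dict, prefix: str, activation_scale_ub=None):
    # Mirrors HybridFP8UnquantLoader: serialized float8_e4m3fn weights keep
    # their scale, anything else is treated as unquantized.
    w = tensors[f"{prefix}.weight"]
    if w.dtype == torch.float8_e4m3fn:
        return {
            "kind": "fp8",
            "weight": w,
            "weight_scale": tensors[f"{prefix}.weight_scale"],
            "activation_scale_ub": activation_scale_ub,
        }
    return {"kind": "unquantized", "weight": w}


ckpt = {"proj.weight": torch.randn(8, 8, dtype=torch.float16)}
print(classify_weight(ckpt, "proj")["kind"])  # unquantized
```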
@@ -82,7 +157,13 @@ class Fp8Linear(torch.nn.Module):
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
-        self.scale_upper_bound = scale_upper_bound
+        self.scale_upper_bound = (
+            torch.tensor(
+                [scale_upper_bound], dtype=torch.float32, device=qweight.device
+            )
+            if scale_upper_bound is not None
+            else None
+        )
 
         self.bias = bias if bias is not None else None
 
@@ -104,7 +185,9 @@ class Fp8Linear(torch.nn.Module):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if HAS_FBGEMM_MM:
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 kernels")
+
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound
             )
@@ -9,11 +9,12 @@ from loguru import logger
 
 from text_generation_server.layers.exl2 import Exl2Weight
 from text_generation_server.layers.gptq import GPTQWeight
+from text_generation_server.utils.log import log_master
 
 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half
 except ImportError:
-    logger.error("exllamav2_kernels not installed.")
+    log_master(logger.warning, "exllamav2_kernels not installed.")
     raise
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
@@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
 )
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_master
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -47,9 +48,7 @@ torch.set_grad_enabled(False)
 
 __all__ = [
     "Model",
-    "BLOOMSharded",
     "CausalLM",
-    "GalacticaSharded",
     "Seq2SeqLM",
     "get_model",
 ]
@@ -125,7 +124,7 @@ try:
     )
     from text_generation_server.layers.attention import SUPPORTS_WINDOWING
 except ImportError as e:
-    logger.warning(f"Could not import Flash Attention enabled models: {e}")
+    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False
 
@@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True
 try:
     from text_generation_server.models.mamba import Mamba
 except ImportError as e:
-    logger.warning(f"Could not import Mamba: {e}")
+    log_master(logger.warning, f"Could not import Mamba: {e}")
     MAMBA_AVAILABLE = False
 
 if MAMBA_AVAILABLE:
@@ -312,8 +311,11 @@ def get_model(
         # These quantizers only work with float16 params.
         dtype = torch.float16
     elif quantize == "fp8":
-        # gemm kernels are fp8xfp8->bf16
-        dtype = torch.bfloat16
+        from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
+
+        if FBGEMM_MM_AVAILABLE:
+            # fbgemm kernels are fp8xfp8->bf16
+            dtype = torch.bfloat16
     else:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
@@ -436,7 +438,9 @@ def get_model(
 
     speculate = get_speculate()
     if speculate > 0:
-        logger.info(f"Using speculation {method} with {speculate} input ids.")
+        log_master(
+            logger.info, f"Using speculation {method} with {speculate} input ids."
+        )
 
     if model_type is None:
         # TODO: fix how we determine model type for Mamba
@@ -451,10 +455,10 @@ def get_model(
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
-            logger.info(f"Auto selecting quantization method {method}")
+            log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
         else:
-            logger.info(f"Unknown quantization method {method}")
+            log_master(logger.warning, f"Unknown quantization method {method}")
 
     if quantize == "exl2" and sharded:
         raise RuntimeError(
@@ -596,7 +600,7 @@ def get_model(
             )
         except RuntimeError as e:
             # Lots of legacy models with various weight names.
-            logger.warning(f"Couldn't load flash gpt2 variant: {e}")
+            log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
             return CausalLM.fallback(
                 model_id,
                 revision,
@@ -418,7 +418,22 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.layers = nn.ModuleList(
+
+        # Skip fp8 quant for first and last layers
+        self.layers = nn.ModuleList()
+        with no_fp8(weights):
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=0,
+                    prefix=(
+                        "model.layers.0" if not prefix else "{prefix}.model.layers.0"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
+        self.layers.extend(
             [
                 FlashLlamaLayer(
                     index=layer_id,
@@ -430,9 +445,26 @@ class FlashLlamaModel(torch.nn.Module):
                     config=config,
                     weights=weights,
                 )
-                for layer_id in range(config.num_hidden_layers)
+                # Skip first and last layers
+                for layer_id in range(1, config.num_hidden_layers - 1)
             ]
         )
+
+        with no_fp8(weights):
+            last_layer_id = config.num_hidden_layers - 1
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=last_layer_id,
+                    prefix=(
+                        f"model.layers.{last_layer_id}"
+                        if not prefix
+                        else f"{prefix}.model.layers.{last_layer_id}"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
         self.norm = FastRMSNorm.load(
             prefix="model.norm" if not prefix else f"{prefix}.model.norm",
             weights=weights,
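The Llama changes above rely on a `no_fp8(weights)` context manager whose definition is not part of this excerpt. A plausible sketch under stated assumptions (the attribute name `weights_loader` and the exact mechanism are guesses, not the repository's confirmed implementation): temporarily swap the active loader for the default unquantized one so the first and last layers skip FP8 quantization.

```python
from contextlib import contextmanager

from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
)


@contextmanager
def no_fp8(weights):
    # Assumed attribute: the Weights object keeps its current WeightsLoader
    # in `weights_loader`; restore it even if layer construction raises.
    previous = weights.weights_loader
    weights.weights_loader = DefaultWeightsLoader(UnquantizedWeight)
    try:
        yield
    finally:
        weights.weights_loader = previous
```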
@@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
+from text_generation_server.utils.log import log_master
 from text_generation_server.utils.tokens import batch_top_tokens
-from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
-    hub,
 )
 from text_generation_server.models.types import (
     Batch,
@@ -1156,31 +1155,36 @@ class FlashCausalLM(Model):
                 f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )
 
-            logger.info(
-                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
+            log_master(
+                logger.info,
+                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
             )
 
             if os.path.isfile(tunableop_filepath):
-                logger.info(
-                    f"The file {tunableop_filepath} already exists and will be reused."
+                log_master(
+                    logger.info,
+                    f"The file {tunableop_filepath} already exists and will be reused.",
                 )
                 torch.cuda.tunable.read_file(tunableop_filepath)
 
             os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
 
             for seqlen in tuning_sequences:
-                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+                log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
             torch.cuda.tunable.write_file(tunableop_filepath)
             torch.cuda.tunable.tuning_enable(False)
         else:
-            logger.info(
-                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+            log_master(
+                logger.info,
+                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
             )
 
         if CUDA_GRAPHS:
             try:
-                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
+                log_master(
+                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                )
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
@@ -1188,7 +1192,9 @@ class FlashCausalLM(Model):
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
         else:
-            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
+            log_master(
+                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+            )
 
         return int(num_blocks * BLOCK_SIZE)
 
@@ -1540,8 +1546,7 @@ class FlashCausalLM(Model):
             left = 0
 
             if n_accepted_ids > 1:
-                if RANK == 0:
-                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
 
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1,15 +1,16 @@
 import torch
 import os
 from loguru import logger
-from typing import Dict
+from typing import Dict, Optional
+
+from text_generation_server.utils.log import log_master
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
 BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
-    logger.info("Using FLASH_DECODING")
+    log_master(logger.info, "Using FLASH_DECODING")
 
 
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
@@ -26,11 +27,9 @@ else:
 if cuda_graphs is not None:
     cuda_graphs.sort(reverse=True)
 
-
 CUDA_GRAPHS = cuda_graphs
 
 # This is overridden at model loading.
-global MODEL_ID
 MODEL_ID = None
 
 
@@ -41,8 +40,7 @@ def set_model_id(model_id: str):
 
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
-global ADAPTER_TO_INDEX
-ADAPTER_TO_INDEX: Dict[str, int] = None
+ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
 
 
 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
@@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     AdapterSource,
 )
+from text_generation_server.utils.log import log_master
 from loguru import logger
 
 
@@ -204,8 +205,9 @@ class Model(ABC):
                 f"order to use the dynamic adapter loading feature."
             )
 
-        logger.info(
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+        log_master(
+            logger.info,
+            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
         )
         weight_names = tuple([v[0] for v in self.target_to_layer.values()])
         (
@@ -240,8 +242,9 @@ class Model(ABC):
             layer_weights.add_adapter(adapter_index, adapter_weights)
 
         if len(unused_weight_names) > 0:
-            logger.warning(
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+            log_master(
+                logger.warning,
+                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
             )
 
         if adapter_tokenizer is not None:
@@ -1,4 +1,3 @@
-from itertools import repeat
 import torch
 from PIL import Image
 from io import BytesIO
@@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
 )
+from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 
 tracer = trace.get_tracer(__name__)
@@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         num_features = get_number_of_features(height, width, config)
         from loguru import logger
 
-        logger.info(
-            f"Found {num_features} features in image of resolution {height}x{width}"
+        log_master(
+            logger.info,
+            f"Found {num_features} features in image of resolution {height}x{width}",
         )
         return "<image>" * num_features
 
@@ -1,6 +1,15 @@
 from functools import lru_cache
+from text_generation_server.utils.dist import RANK
 
 
 @lru_cache(10)
-def log_once(log, msg: str):
-    log(msg)
+def log_once(log, msg: str, master=True):
+    if master:
+        log_master(log, msg)
+    else:
+        log(msg)
+
+
+def log_master(log, msg: str):
+    if RANK == 0:
+        log(msg)
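Taken together, the refactor routes almost every `logger.*` call through these helpers so that a sharded (tensor-parallel) server prints each message once instead of once per rank. A standalone sketch of the idea, with `RANK` stubbed from the usual `torchrun` environment variable rather than from `text_generation_server.utils.dist`:

```python
import os
from functools import lru_cache

from loguru import logger

RANK = int(os.getenv("RANK", "0"))  # assumption: rank taken from the environment


def log_master(log, msg: str):
    # Only shard 0 emits the message.
    if RANK == 0:
        log(msg)


@lru_cache(10)
def log_once(log, msg: str, master=True):
    # lru_cache de-duplicates identical (log, msg) calls within a process.
    if master:
        log_master(log, msg)
    else:
        log(msg)


log_master(logger.warning, "LoRA adapters are enabled.")  # printed on rank 0 only
log_once(logger.info, "Using FBGEMM fp8 kernels")         # printed at most once
```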
@@ -11,6 +11,7 @@ from text_generation_server.utils.weights import (
 )
 
 
+# TODO: Split this config to have a single config type per quant method
 @dataclass
 class _QuantizerConfig:
     bits: int
@@ -21,6 +22,11 @@ class _QuantizerConfig:
     sym: bool
 
 
+@dataclass
+class _FP8QuantizerConfig:
+    activation_scale_ub: float
+
+
 # We should probably do this with Pytantic JSON deserialization,
 # but for now we'll stay close to the old _set_gptq_params.
 def _get_quantizer_config(model_id, revision):
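For context, an FP8 checkpoint exports an upper bound for its activation scale in the quantization section of its config; the snippet below is illustrative only (the field values and the `quant_method` name are assumptions), but it shows the single value that `_FP8QuantizerConfig` captures.

```python
import json

# Hypothetical excerpt of a config.json shipped with an FP8 checkpoint.
raw = json.loads("""
{
  "quantization_config": {
    "quant_method": "fbgemm_fp8",
    "activation_scale_ub": 1224.0
  }
}
""")

activation_scale_ub = raw["quantization_config"]["activation_scale_ub"]
print(activation_scale_ub)  # stored as _FP8QuantizerConfig(activation_scale_ub=...)
```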
@@ -63,6 +69,17 @@ def _get_quantizer_config(model_id, revision):
             desc_act = data["desc_act"]
             if "version" in data and data["version"] == "GEMM":
                 quant_method = "awq"
+    # FP8 config
+    except KeyError:
+        try:
+            filename = os.path.join(model_id, filename)
+            with open(filename, "r") as f:
+                data = json.load(f)
+            return _FP8QuantizerConfig(
+                activation_scale_ub=data["activation_scale_ub"]
+            )
+        except:
+            pass
     except Exception:
         filename = "quant_config.json"
         try:
@@ -99,6 +116,12 @@ def get_loader(
     if quantize in {"awq", "gptq"}:
         from text_generation_server.layers.gptq import GPTQWeightsLoader
 
+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
         return GPTQWeightsLoader(
             bits=quantizer_config.bits,
             desc_act=quantizer_config.desc_act,
@@ -127,18 +150,28 @@ def get_loader(
         from text_generation_server.layers.exl2 import Exl2WeightsLoader
 
         return Exl2WeightsLoader()
-    elif quantize == "fp8":
-        from text_generation_server.layers.fp8 import Fp8Weight
-
-        return DefaultWeightsLoader(Fp8Weight)
     elif quantize == "marlin":
         from text_generation_server.layers.marlin import MarlinWeightsLoader
 
+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
         return MarlinWeightsLoader(
             bits=quantizer_config.bits,
             is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
         )
-    elif quantize is None:
-        return DefaultWeightsLoader()
+    elif quantize == "fp8" or quantize is None:
+        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+        # Since the default for the quantize config is _QuantizerConfig,
+        # we need to add this check to not get an attribute error
+        activation_scale_ub = None
+        if isinstance(quantizer_config, _FP8QuantizerConfig):
+            activation_scale_ub = quantizer_config.activation_scale_ub
+
+        return HybridFP8UnquantLoader(activation_scale_ub)
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
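A minimal sketch of the selection rule above, using standalone stand-ins rather than the repository's actual objects: both `--quantize fp8` and no quantize flag end up on the hybrid FP8/unquantized path, and the activation bound is carried along only when an FP8 config was parsed.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FP8Config:
    activation_scale_ub: float


def pick_loader(quantize: Optional[str], quantizer_config):
    if quantize in ("fp8", None):
        ub = (
            quantizer_config.activation_scale_ub
            if isinstance(quantizer_config, FP8Config)
            else None
        )
        return ("hybrid_fp8_unquant", ub)
    raise ValueError(f"Unknown quantization method: {quantize}")


print(pick_loader(None, FP8Config(activation_scale_ub=1224.0)))  # ('hybrid_fp8_unquant', 1224.0)
print(pick_loader("fp8", object()))                              # ('hybrid_fp8_unquant', None)
```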
@@ -3,7 +3,7 @@ import torch
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Type
 from safetensors import safe_open
 from dataclasses import dataclass
 
@@ -84,9 +84,8 @@ class Weight(ABC):
 
 
 @dataclass
-class UnquantizedWeight:
+class UnquantizedWeight(Weight):
     weight: torch.Tensor
-    dtype: torch.dtype
 
     def get_linear(self, bias: torch.Tensor):
         from text_generation_server.layers.linear import FastLinear, FastLinearROCm
@@ -100,7 +99,7 @@ class UnquantizedWeight:
 class DefaultWeightsLoader(WeightsLoader):
     """Weight loader that loads (unquantized) Torch tensors."""
 
-    def __init__(self, weight_class: Optional = None):
+    def __init__(self, weight_class: Type[UnquantizedWeight]):
         """Create a loader. Weights will be wrapped using the given `weights_class`,
         normally this will be `UnquantizedWeight`, but a quantizer-specific class
         such as `Fp8Weight` can be used to quantize the weights during loading.
@@ -121,92 +120,21 @@ class DefaultWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-        w = weights.get_packed_sharded(
-            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        return self.weight_class(
+            weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            ),
         )
 
-        if w.dtype == torch.float8_e4m3fn:
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            if self.weight_class is not None and self.weight_class != Fp8Weight:
-                raise RuntimeError(
-                    f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
-                )
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, cast=False
-            )
-            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
-            return Fp8Weight(
-                weight=w,
-                weight_scale=scale,
-                input_scale=input_scale,
-                dtype=weights.dtype,
-            )
-
-        if self.weight_class is None:
-            return UnquantizedWeight(w, dtype=weights.dtype)
-        return self.weight_class(w, dtype=weights.dtype)
-
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        w = torch.cat(w, dim=dim)
-
-        # FP8 branch
-        if w.dtype == torch.float8_e4m3fn:
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            if self.weight_class is not None and self.weight_class != Fp8Weight:
-                raise RuntimeError(
-                    f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
-                )
-
-            scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, cast=False)
-                for p in prefixes
-            ]
-            scale = torch.cat(scale, dim=0)
-            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale", cast=False)
-            return Fp8Weight(
-                weight=w,
-                weight_scale=scale,
-                input_scale=input_scale,
-                dtype=weights.dtype,
-            )
-
-        if self.weight_class is None:
-            return UnquantizedWeight(w, dtype=weights.dtype)
-        return self.weight_class(w, dtype=weights.dtype)
+        return self.weight_class(torch.cat(w, dim=dim))
 
     def get_weights_row(self, weights: "Weights", prefix: str):
-        w = weights.get_sharded(f"{prefix}.weight", dim=1)
-        # FP8 branch
-        if w.dtype == torch.float8_e4m3fn:
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            if self.weight_class is not None and self.weight_class != Fp8Weight:
-                raise RuntimeError(
-                    f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
-                )
-
-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, cast=False)
-            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
-            return Fp8Weight(
-                weight=w,
-                weight_scale=scale,
-                input_scale=input_scale,
-                dtype=weights.dtype,
-            )
-
-        if self.weight_class is None:
-            return UnquantizedWeight(w, dtype=weights.dtype)
-        return self.weight_class(w, dtype=weights.dtype)
+        return self.weight_class(
+            weights.get_sharded(f"{prefix}.weight", dim=1),
+        )
 
 
 class Weights:
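With the FP8 branches gone, `DefaultWeightsLoader` is reduced to wrapping whatever the sharded tensor lookup returns in the configured weight class. A toy illustration with a stand-in weights object (the names below are illustrative, not the repository's API):

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyUnquantizedWeight:
    weight: torch.Tensor


class ToyDefaultLoader:
    def __init__(self, weight_class):
        self.weight_class = weight_class

    def get_weights_row(self, tensors: dict, prefix: str):
        # No dtype inspection any more: just wrap the raw tensor.
        return self.weight_class(tensors[f"{prefix}.weight"])


loader = ToyDefaultLoader(ToyUnquantizedWeight)
w = loader.get_weights_row({"fc.weight": torch.randn(4, 4)}, "fc")
print(type(w).__name__)  # ToyUnquantizedWeight
```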
@@ -280,7 +208,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()
 
-    def get_tensor(self, tensor_name: str, to_device=True, cast=True):
+    def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -295,14 +223,14 @@ class Weights:
                 torch.int32,
                 torch.int64,
             ]
-            and cast
+            and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_partial_sharded(self, tensor_name: str, dim: int, cast=True):
+    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
|
|||||||
# Special case for gptq which shouldn't convert
|
# Special case for gptq which shouldn't convert
|
||||||
# u4 which are disguised as int32. exl2 uses int16.
|
# u4 which are disguised as int32. exl2 uses int16.
|
||||||
# FP8 uses torch.float8_e4m3fn.
|
# FP8 uses torch.float8_e4m3fn.
|
||||||
if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) and cast:
|
if (
|
||||||
|
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
|
||||||
|
and to_dtype
|
||||||
|
):
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
tensor = tensor.to(device=self.device)
|
tensor = tensor.to(device=self.device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def get_sharded(self, tensor_name: str, dim: int, cast=True):
|
def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
slice_ = f.get_slice(tensor_name)
|
slice_ = f.get_slice(tensor_name)
|
||||||
@ -337,10 +268,14 @@ class Weights:
|
|||||||
assert (
|
assert (
|
||||||
size % world_size == 0
|
size % world_size == 0
|
||||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||||
return self.get_partial_sharded(tensor_name, dim, cast=cast)
|
return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
|
||||||
|
|
||||||
def get_packed_sharded(
|
def get_packed_sharded(
|
||||||
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]], cast=True
|
self,
|
||||||
|
tensor_name: str,
|
||||||
|
dim: int,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
to_dtype=True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Get a shard from a tensor that packs multiple tensors.
|
Get a shard from a tensor that packs multiple tensors.
|
||||||
@@ -394,7 +329,7 @@ class Weights:
                 torch.int32,
                 torch.int64,
             ]
-            and cast
+            and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
 