refactored weights loader

OlivierDehaene · 2024-07-20 09:02:02 +02:00
commit 6a93a24f3f · parent 081d16cab5
No known key found for this signature in database · GPG Key ID: BB104D67809DA93C
15 changed files with 274 additions and 171 deletions

@@ -8,6 +8,7 @@ from typing import Optional
 from enum import Enum
 from huggingface_hub import hf_hub_download
+from text_generation_server.utils.log import log_master

 app = typer.Typer()

@@ -87,15 +88,17 @@ def serve(
         )
     if len(lora_adapter_ids) > 0:
-        logger.warning(
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        log_master(
+            logger.warning,
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
         )
     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
     if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
-        logger.warning(
-            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
+        log_master(
+            logger.warning,
+            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
        )
        global CUDA_GRAPHS
        CUDA_GRAPHS = None

@@ -3,6 +3,7 @@ import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers.attention import Seqlen
+from text_generation_server.utils.log import log_master
 from loguru import logger

 major, minor = torch.cuda.get_device_capability()

@@ -136,7 +137,10 @@ if ENGINE != "triton":
    try:
        import flash_attn_2_cuda

-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+        log_master(
+            logger.info,
+            "ROCm: using Flash Attention 2 Composable Kernel implementation.",
+        )
    except ImportError as e:
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"

@@ -4,19 +4,11 @@ from functools import lru_cache
 import bitsandbytes as bnb
 import torch
 from bitsandbytes.nn import Int8Params, Params4bit
-from loguru import logger
-from text_generation_server.utils.weights import Weight
-
-
-@lru_cache(1)
-def warn_deprecate_bnb():
-    logger.warning(
-        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
-    )
+from text_generation_server.utils.weights import UnquantizedWeight


 @dataclass
-class BNBWeight(Weight):
+class BNBWeight(UnquantizedWeight):
     weight: torch.Tensor

     def get_linear(self, bias: torch.Tensor):

@@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module):

 @dataclass
-class BNBFP4Weight(Weight):
+class BNBFP4Weight(UnquantizedWeight):
     weight: torch.Tensor

     def get_linear(self, bias: torch.Tensor):

@@ -90,7 +82,7 @@ class BNBFP4Weight(Weight):

 @dataclass
-class BNBNF4Weight(Weight):
+class BNBNF4Weight(UnquantizedWeight):
     weight: torch.Tensor

     def get_linear(self, bias: torch.Tensor):

@@ -2,11 +2,11 @@ from dataclasses import dataclass
 import torch
 from EETQ import quant_weights, w8_a16_gemm

-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import UnquantizedWeight


 @dataclass
-class EETQWeight(Weight):
+class EETQWeight(UnquantizedWeight):
     weight: torch.Tensor

     def get_linear(self, bias: torch.Tensor):

@@ -1,19 +1,29 @@
 import torch
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union, List
+from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import (
+    Weight,
+    WeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)
+from text_generation_server.utils.log import log_master, log_once
+
+FBGEMM_MM_AVAILABLE = False
+FBGEMM_DYN_AVAILABLE = False

 try:
     import fbgemm_gpu.experimental.gen_ai

-    major, _ = torch.cuda.get_device_capability()
-    HAS_FBGEMM_MM = major == 9
-    HAS_FBGEMM_DYN = major >= 8
+    if SYSTEM == "cuda":
+        major, _ = torch.cuda.get_device_capability()
+        FBGEMM_MM_AVAILABLE = major == 9
+        FBGEMM_DYN_AVAILABLE = major >= 8
 except (ImportError, ModuleNotFoundError):
-    HAS_FBGEMM_MM = False
-    HAS_FBGEMM_DYN = False
+    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")

@@ -33,7 +43,7 @@ def get_fp8_linear() -> torch.nn.Module:

 def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if HAS_FBGEMM_DYN:
+    if FBGEMM_DYN_AVAILABLE:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )

@@ -54,18 +64,83 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
     return qweight, scale


+class HybridFP8UnquantLoader(WeightsLoader):
+    """Weight loader that loads FP8 and unquantized Torch tensors."""
+
+    def __init__(self, activation_scale_ub: Optional[float]):
+        self.activation_scale_ub = activation_scale_ub
+
+    def get_weights_col_packed(
+        self,
+        weights: Weights,
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+
+        if w.dtype == torch.float8_e4m3fn:
+            # FP8 branch
+            scale = weights.get_packed_sharded(
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
+            )
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+
+        return UnquantizedWeight(w)
+
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        w = torch.cat(w, dim=dim)
+
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = [
+                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
+                for p in prefixes
+            ]
+            scale = torch.cat(scale, dim=0)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+
+        return UnquantizedWeight(w)
+
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+
+        return UnquantizedWeight(w)
+
+
 @dataclass
-class Fp8Weight:
+class Fp8Weight(Weight):
     weight: torch.Tensor
     dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
-    input_scale: Optional[torch.Tensor] = None
+    activation_scale_ub: Optional[float] = None

     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.input_scale, bias, self.dtype
+            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )

@@ -82,7 +157,13 @@ class Fp8Linear(torch.nn.Module):
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
-        self.scale_upper_bound = scale_upper_bound
+        self.scale_upper_bound = (
+            torch.tensor(
+                [scale_upper_bound], dtype=torch.float32, device=qweight.device
+            )
+            if scale_upper_bound is not None
+            else None
+        )

         self.bias = bias if bias is not None else None

@@ -104,7 +185,9 @@ class Fp8Linear(torch.nn.Module):
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if HAS_FBGEMM_MM:
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 kernels")
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound
             )
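Note: the key behavior introduced above is that `HybridFP8UnquantLoader` decides per tensor whether to build an `Fp8Weight` or an `UnquantizedWeight` purely from the serialized dtype. A minimal, self-contained sketch of that dispatch rule (toy function, not the TGI classes):

```python
# Toy illustration of the dispatch rule used by HybridFP8UnquantLoader:
# tensors stored as float8_e4m3fn take the FP8 path (weight_scale plus
# activation_scale_ub), everything else is returned unquantized.
import torch


def pick_weight_kind(tensor: torch.Tensor) -> str:
    if tensor.dtype == torch.float8_e4m3fn:
        return "Fp8Weight"
    return "UnquantizedWeight"


print(pick_weight_kind(torch.zeros(2, 2, dtype=torch.float8_e4m3fn)))  # Fp8Weight
print(pick_weight_kind(torch.zeros(2, 2, dtype=torch.float16)))        # UnquantizedWeight
```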

@@ -9,11 +9,12 @@ from loguru import logger

 from text_generation_server.layers.exl2 import Exl2Weight
 from text_generation_server.layers.gptq import GPTQWeight
+from text_generation_server.utils.log import log_master

 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half
 except ImportError:
-    logger.error("exllamav2_kernels not installed.")
+    log_master(logger.warning, "exllamav2_kernels not installed.")
     raise

 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension

@@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
 )
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_master

 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.

@@ -47,9 +48,7 @@ torch.set_grad_enabled(False)
 __all__ = [
     "Model",
-    "BLOOMSharded",
     "CausalLM",
-    "GalacticaSharded",
     "Seq2SeqLM",
     "get_model",
 ]

@@ -125,7 +124,7 @@
     )
     from text_generation_server.layers.attention import SUPPORTS_WINDOWING
 except ImportError as e:
-    logger.warning(f"Could not import Flash Attention enabled models: {e}")
+    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False

@@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True
 try:
     from text_generation_server.models.mamba import Mamba
 except ImportError as e:
-    logger.warning(f"Could not import Mamba: {e}")
+    log_master(logger.warning, f"Could not import Mamba: {e}")
     MAMBA_AVAILABLE = False

 if MAMBA_AVAILABLE:

@@ -312,8 +311,11 @@ def get_model(
         # These quantizers only work with float16 params.
         dtype = torch.float16
     elif quantize == "fp8":
-        # gemm kernels are fp8xfp8->bf16
-        dtype = torch.bfloat16
+        from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
+
+        if FBGEMM_MM_AVAILABLE:
+            # fbgemm kernels are fp8xfp8->bf16
+            dtype = torch.bfloat16
     else:
         # Keep it as default for now and let
         # every model resolve their own default dtype.

@@ -436,7 +438,9 @@ def get_model(
     speculate = get_speculate()
     if speculate > 0:
-        logger.info(f"Using speculation {method} with {speculate} input ids.")
+        log_master(
+            logger.info, f"Using speculation {method} with {speculate} input ids."
+        )

     if model_type is None:
         # TODO: fix how we determine model type for Mamba

@@ -451,10 +455,10 @@ def get_model(
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
-            logger.info(f"Auto selecting quantization method {method}")
+            log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
         else:
-            logger.info(f"Unknown quantization method {method}")
+            log_master(logger.warning, f"Unknown quantization method {method}")

     if quantize == "exl2" and sharded:
         raise RuntimeError(

@@ -596,7 +600,7 @@ def get_model(
             )
         except RuntimeError as e:
             # Lots of legacy models with various weight names.
-            logger.warning(f"Couldn't load flash gpt2 variant: {e}")
+            log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
             return CausalLM.fallback(
                 model_id,
                 revision,

@@ -418,7 +418,22 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.layers = nn.ModuleList(
+
+        # Skip fp8 quant for first and last layers
+        self.layers = nn.ModuleList()
+        with no_fp8(weights):
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=0,
+                    prefix=(
+                        "model.layers.0" if not prefix else "{prefix}.model.layers.0"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
+        self.layers.extend(
             [
                 FlashLlamaLayer(
                     index=layer_id,

@@ -430,9 +445,26 @@
                     config=config,
                     weights=weights,
                 )
-                for layer_id in range(config.num_hidden_layers)
+                # Skip first and last layers
+                for layer_id in range(1, config.num_hidden_layers - 1)
             ]
         )
+
+        with no_fp8(weights):
+            last_layer_id = config.num_hidden_layers - 1
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=last_layer_id,
+                    prefix=(
+                        f"model.layers.{last_layer_id}"
+                        if not prefix
+                        else f"{prefix}.model.layers.{last_layer_id}"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
         self.norm = FastRMSNorm.load(
             prefix="model.norm" if not prefix else f"{prefix}.model.norm",
             weights=weights,
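Note: `no_fp8(weights)` is used above so the first and last decoder layers are loaded unquantized even when the rest of the model goes through the FP8 loader. A hedged sketch of how such a helper could be shaped (the `weights_loader` attribute and `unquantized_loader` argument are assumptions for illustration, not the TGI implementation):

```python
# Hedged sketch only: one plausible shape for a `no_fp8`-style helper that
# temporarily swaps the loader attached to the Weights object, then restores it.
from contextlib import contextmanager


@contextmanager
def no_fp8_sketch(weights, unquantized_loader):
    saved_loader = weights.weights_loader      # assumed attribute
    weights.weights_loader = unquantized_loader
    try:
        yield
    finally:
        weights.weights_loader = saved_loader
```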

@@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
+from text_generation_server.utils.log import log_master
 from text_generation_server.utils.tokens import batch_top_tokens
-from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
-    hub,
 )
 from text_generation_server.models.types import (
     Batch,

@@ -1156,31 +1155,36 @@ class FlashCausalLM(Model):
                 f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )

-            logger.info(
-                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
+            log_master(
+                logger.info,
+                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
             )

             if os.path.isfile(tunableop_filepath):
-                logger.info(
-                    f"The file {tunableop_filepath} already exists and will be reused."
+                log_master(
+                    logger.info,
+                    f"The file {tunableop_filepath} already exists and will be reused.",
                 )
                 torch.cuda.tunable.read_file(tunableop_filepath)

             os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

             for seqlen in tuning_sequences:
-                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+                log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
             torch.cuda.tunable.write_file(tunableop_filepath)
             torch.cuda.tunable.tuning_enable(False)
         else:
-            logger.info(
-                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+            log_master(
+                logger.info,
+                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
             )

         if CUDA_GRAPHS:
             try:
-                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
+                log_master(
+                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                )
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:

@@ -1188,7 +1192,9 @@
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
         else:
-            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
+            log_master(
+                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+            )

         return int(num_blocks * BLOCK_SIZE)

@@ -1540,8 +1546,7 @@
             left = 0

             if n_accepted_ids > 1:
-                if RANK == 0:
-                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")

             current_stopped = False
             for j in range(index, index + n_accepted_ids):

@@ -1,15 +1,16 @@
 import torch
 import os

 from loguru import logger
-from typing import Dict
+from typing import Dict, Optional
+from text_generation_server.utils.log import log_master

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
 BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
-    logger.info("Using FLASH_DECODING")
+    log_master(logger.info, "Using FLASH_DECODING")

 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:

@@ -26,11 +27,9 @@ else:
 if cuda_graphs is not None:
     cuda_graphs.sort(reverse=True)
 CUDA_GRAPHS = cuda_graphs

 # This is overridden at model loading.
-global MODEL_ID
 MODEL_ID = None

@@ -41,8 +40,7 @@ def set_model_id(model_id: str):
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
-global ADAPTER_TO_INDEX
-ADAPTER_TO_INDEX: Dict[str, int] = None
+ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


 def set_adapter_to_index(adapter_to_index: Dict[str, int]):

@@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     AdapterSource,
 )
+from text_generation_server.utils.log import log_master
 from loguru import logger

@@ -204,8 +205,9 @@ class Model(ABC):
                 f"order to use the dynamic adapter loading feature."
             )

-        logger.info(
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+        log_master(
+            logger.info,
+            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
         )
         weight_names = tuple([v[0] for v in self.target_to_layer.values()])
         (

@@ -240,8 +242,9 @@ class Model(ABC):
             layer_weights.add_adapter(adapter_index, adapter_weights)

         if len(unused_weight_names) > 0:
-            logger.warning(
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+            log_master(
+                logger.warning,
+                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
             )

         if adapter_tokenizer is not None:

@@ -1,4 +1,3 @@
-from itertools import repeat
 import torch
 from PIL import Image
 from io import BytesIO

@@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
 )
+from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor

 tracer = trace.get_tracer(__name__)

@@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         num_features = get_number_of_features(height, width, config)
         from loguru import logger

-        logger.info(
-            f"Found {num_features} features in image of resolution {height}x{width}"
+        log_master(
+            logger.info,
+            f"Found {num_features} features in image of resolution {height}x{width}",
         )
         return "<image>" * num_features

@@ -1,6 +1,15 @@
 from functools import lru_cache
+from text_generation_server.utils.dist import RANK


 @lru_cache(10)
-def log_once(log, msg: str):
-    log(msg)
+def log_once(log, msg: str, master=True):
+    if master:
+        log_master(log, msg)
+    else:
+        log(msg)
+
+
+def log_master(log, msg: str):
+    if RANK == 0:
+        log(msg)
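Note: with this change, rank gating and deduplication of log messages live in one place. A short usage sketch, assuming it runs inside a TGI worker process where `text_generation_server.utils.dist.RANK` is already initialized:

```python
# Usage sketch for the new helpers (assumes a TGI worker process where RANK is set).
from loguru import logger
from text_generation_server.utils.log import log_master, log_once

# Emitted only on rank 0, so sharded workers don't repeat the same message.
log_master(logger.info, "weights loader initialized")

# Cached via lru_cache, so repeated identical calls only log once (on rank 0 by default).
log_once(logger.warning, "this warning is deduplicated")
```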

@@ -11,6 +11,7 @@ from text_generation_server.utils.weights import (
 )


+# TODO: Split this config to have a single config type per quant method
 @dataclass
 class _QuantizerConfig:
     bits: int

@@ -21,6 +22,11 @@ class _QuantizerConfig:
     sym: bool


+@dataclass
+class _FP8QuantizerConfig:
+    activation_scale_ub: float
+
+
 # We should probably do this with Pytantic JSON deserialization,
 # but for now we'll stay close to the old _set_gptq_params.
 def _get_quantizer_config(model_id, revision):

@@ -63,6 +69,17 @@ def _get_quantizer_config(model_id, revision):
             desc_act = data["desc_act"]
             if "version" in data and data["version"] == "GEMM":
                 quant_method = "awq"
+    # FP8 config
+    except KeyError:
+        try:
+            filename = os.path.join(model_id, filename)
+            with open(filename, "r") as f:
+                data = json.load(f)
+            return _FP8QuantizerConfig(
+                activation_scale_ub=data["activation_scale_ub"]
+            )
+        except:
+            pass
     except Exception:
         filename = "quant_config.json"
         try:

@@ -99,6 +116,12 @@ def get_loader(
     if quantize in {"awq", "gptq"}:
         from text_generation_server.layers.gptq import GPTQWeightsLoader

+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
         return GPTQWeightsLoader(
             bits=quantizer_config.bits,
             desc_act=quantizer_config.desc_act,

@@ -127,18 +150,28 @@ def get_loader(
         from text_generation_server.layers.exl2 import Exl2WeightsLoader

         return Exl2WeightsLoader()
-    elif quantize == "fp8":
-        from text_generation_server.layers.fp8 import Fp8Weight
-
-        return DefaultWeightsLoader(Fp8Weight)
     elif quantize == "marlin":
         from text_generation_server.layers.marlin import MarlinWeightsLoader

+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
         return MarlinWeightsLoader(
             bits=quantizer_config.bits,
             is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
         )
-    elif quantize is None:
-        return DefaultWeightsLoader()
+    elif quantize == "fp8" or quantize is None:
+        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+        # Since the default for the quantize config is _QuantizerConfig,
+        # we need to add this check to not get an attribute error
+        activation_scale_ub = None
+        if isinstance(quantizer_config, _FP8QuantizerConfig):
+            activation_scale_ub = quantizer_config.activation_scale_ub
+
+        return HybridFP8UnquantLoader(activation_scale_ub)
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
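Note: the net effect in `get_loader` is that `fp8` and unquantized checkpoints now share `HybridFP8UnquantLoader`, with `activation_scale_ub` forwarded only when an FP8 quantization config was parsed. A toy restatement of that rule (stand-in dataclass, not the real `_FP8QuantizerConfig`):

```python
# Toy restatement of the branch above, not the real get_loader: the activation
# scale upper bound is only forwarded when the parsed config is the FP8 variant.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FP8ConfigSketch:  # stand-in for _FP8QuantizerConfig
    activation_scale_ub: float


def pick_activation_scale_ub(quantizer_config) -> Optional[float]:
    if isinstance(quantizer_config, FP8ConfigSketch):
        return quantizer_config.activation_scale_ub
    return None


assert pick_activation_scale_ub(FP8ConfigSketch(1216.0)) == 1216.0
assert pick_activation_scale_ub(object()) is None
```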

@@ -3,7 +3,7 @@ import torch
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Type
 from safetensors import safe_open
 from dataclasses import dataclass

@@ -84,9 +84,8 @@ class Weight(ABC):
 @dataclass
-class UnquantizedWeight:
+class UnquantizedWeight(Weight):
     weight: torch.Tensor
-    dtype: torch.dtype

     def get_linear(self, bias: torch.Tensor):
         from text_generation_server.layers.linear import FastLinear, FastLinearROCm

@@ -100,7 +99,7 @@ class UnquantizedWeight:
 class DefaultWeightsLoader(WeightsLoader):
     """Weight loader that loads (unquantized) Torch tensors."""

-    def __init__(self, weight_class: Optional = None):
+    def __init__(self, weight_class: Type[UnquantizedWeight]):
         """Create a loader. Weights will be wrapped using the given `weights_class`,
         normally this will be `UnquantizedWeight`, but a quantizer-specific class
         such as `Fp8Weight` can be used to quantize the weights during loading.

@@ -121,92 +120,21 @@
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-        w = weights.get_packed_sharded(
-            f"{prefix}.weight", dim=0, block_sizes=block_sizes
-        )
-        if w.dtype == torch.float8_e4m3fn:
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            if self.weight_class is not None and self.weight_class != Fp8Weight:
-                raise RuntimeError(
-                    f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
-                )
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, cast=False
-            )
-            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
-            return Fp8Weight(
-                weight=w,
-                weight_scale=scale,
-                input_scale=input_scale,
-                dtype=weights.dtype,
-            )
-        if self.weight_class is None:
-            return UnquantizedWeight(w, dtype=weights.dtype)
-        return self.weight_class(w, dtype=weights.dtype)
+        return self.weight_class(
+            weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            ),
+        )

     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        w = torch.cat(w, dim=dim)
-        # FP8 branch
-        if w.dtype == torch.float8_e4m3fn:
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            if self.weight_class is not None and self.weight_class != Fp8Weight:
-                raise RuntimeError(
-                    f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
-                )
-            scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, cast=False)
-                for p in prefixes
-            ]
-            scale = torch.cat(scale, dim=0)
-            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale", cast=False)
-            return Fp8Weight(
-                weight=w,
-                weight_scale=scale,
-                input_scale=input_scale,
-                dtype=weights.dtype,
-            )
-        if self.weight_class is None:
-            return UnquantizedWeight(w, dtype=weights.dtype)
-        return self.weight_class(w, dtype=weights.dtype)
+        return self.weight_class(torch.cat(w, dim=dim))

     def get_weights_row(self, weights: "Weights", prefix: str):
-        w = weights.get_sharded(f"{prefix}.weight", dim=1)
-        # FP8 branch
-        if w.dtype == torch.float8_e4m3fn:
-            # FIXME: here to avoid circular import
-            from text_generation_server.layers.fp8 import Fp8Weight
-
-            if self.weight_class is not None and self.weight_class != Fp8Weight:
-                raise RuntimeError(
-                    f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
-                )
-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, cast=False)
-            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
-            return Fp8Weight(
-                weight=w,
-                weight_scale=scale,
-                input_scale=input_scale,
-                dtype=weights.dtype,
-            )
-        if self.weight_class is None:
-            return UnquantizedWeight(w, dtype=weights.dtype)
-        return self.weight_class(w, dtype=weights.dtype)
+        return self.weight_class(
+            weights.get_sharded(f"{prefix}.weight", dim=1),
+        )


 class Weights:

@@ -280,7 +208,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str, to_device=True, cast=True):
+    def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)

@@ -295,14 +223,14 @@
                 torch.int32,
                 torch.int64,
             ]
-            and cast
+            and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
         return tensor

-    def get_partial_sharded(self, tensor_name: str, dim: int, cast=True):
+    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)

@@ -323,12 +251,15 @@
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
         # FP8 uses torch.float8_e4m3fn.
-        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) and cast:
+        if (
+            tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
+            and to_dtype
+        ):
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int, cast=True):
+    def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)

@@ -337,10 +268,14 @@
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim, cast=cast)
+        return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)

     def get_packed_sharded(
-        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]], cast=True
+        self,
+        tensor_name: str,
+        dim: int,
+        block_sizes: Union[int, List[int]],
+        to_dtype=True,
     ) -> torch.Tensor:
         """
         Get a shard from a tensor that packs multiple tensors.

@@ -394,7 +329,7 @@
                 torch.int32,
                 torch.int64,
             ]
-            and cast
+            and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
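Note: after this refactor `DefaultWeightsLoader` is intentionally dumb: it reads a tensor and wraps it in the configured weight class, with all FP8 special-casing moved into `HybridFP8UnquantLoader`. A self-contained sketch of that simplified contract (the classes below are illustrative stand-ins, not the text_generation_server types):

```python
# Self-contained sketch of the simplified DefaultWeightsLoader contract.
from dataclasses import dataclass

import torch


@dataclass
class UnquantizedWeightSketch:  # mirrors UnquantizedWeight(weight=...)
    weight: torch.Tensor


class FakeWeights:  # mirrors the small part of Weights used here
    def get_sharded(self, name: str, dim: int) -> torch.Tensor:
        return torch.zeros(4, 4, dtype=torch.float16)


class DefaultLoaderSketch:  # mirrors DefaultWeightsLoader after the refactor
    def __init__(self, weight_class):
        self.weight_class = weight_class

    def get_weights_row(self, weights, prefix: str):
        # No dtype inspection anymore: just wrap whatever was read.
        return self.weight_class(weights.get_sharded(f"{prefix}.weight", dim=1))


loader = DefaultLoaderSketch(UnquantizedWeightSketch)
w = loader.get_weights_row(FakeWeights(), "model.layers.0.mlp.down_proj")
print(type(w).__name__, tuple(w.weight.shape))  # UnquantizedWeightSketch (4, 4)
```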