Clarify FP8-Marlin use on capability 8.9 (#2940)

The log message stated that the GPU does not support FP8 on capability
8.9. However we use FP8-Marlin on that capability because it is faster.
This commit is contained in:
Daniël de Kok 2025-01-22 18:18:11 +01:00 committed by GitHub
parent 1d3c9beba8
commit 1dd346666a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 4 deletions

View File

@ -52,6 +52,16 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
# gives better decoding throughput on L4 and L40. # gives better decoding throughput on L4 and L40.
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
if major == 8 and minor == 9:
log_once(
logger.info,
"GPU supports FP8, but using Marlin FP8 kernel for better performance",
)
else:
log_once(
logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
)
return GPTQMarlinFP8Linear return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8. # On other systems let Torch decide if the hardware supports FP8.

View File

@ -2,14 +2,12 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import ( from text_generation_server.layers.marlin.util import (
_check_marlin_kernels, _check_marlin_kernels,
permute_scales, permute_scales,
) )
from text_generation_server.utils.log import log_once
try: try:
import marlin_kernels import marlin_kernels
@ -36,8 +34,6 @@ class GPTQMarlinFP8Linear(nn.Module):
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0) scales = scales.unsqueeze(0)
if scales.shape[1] == 1: if scales.shape[1] == 1:
out_features, in_features = qweight.shape out_features, in_features = qweight.shape