From 1dd346666a3b354f2c7542141bc1928a47c452db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 22 Jan 2025 18:18:11 +0100
Subject: [PATCH] Clarify FP8-Marlin use on capability 8.9 (#2940)

The log message stated that the GPU does not support FP8 on capability
8.9. However, we use FP8-Marlin on that capability because it is faster.
---
 server/text_generation_server/layers/fp8.py        | 10 ++++++++++
 server/text_generation_server/layers/marlin/fp8.py |  4 ----
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 4e83ec9d..715974ff 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -52,6 +52,16 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
         # gives better decoding throughput on L4 and L40.
         from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
+        if major == 8 and minor == 9:
+            log_once(
+                logger.info,
+                "GPU supports FP8, but using Marlin FP8 kernel for better performance",
+            )
+        else:
+            log_once(
+                logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
+            )
+
         return GPTQMarlinFP8Linear
 
     # On other systems let Torch decide if the hardware supports FP8.
diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py
index 49f5c480..e07b9fc6 100644
--- a/server/text_generation_server/layers/marlin/fp8.py
+++ b/server/text_generation_server/layers/marlin/fp8.py
@@ -2,14 +2,12 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from loguru import logger
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.layers.marlin.gptq import _check_valid_shape
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
     permute_scales,
 )
-from text_generation_server.utils.log import log_once
 
 try:
     import marlin_kernels
@@ -36,8 +34,6 @@ class GPTQMarlinFP8Linear(nn.Module):
         _check_marlin_kernels()
         assert marlin_kernels is not None
 
-        log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
-
         scales = scales.unsqueeze(0)
 
         if scales.shape[1] == 1:
             out_features, in_features = qweight.shape
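
A minimal, self-contained sketch of the capability check this patch adds,
for illustration only. It assumes major/minor come from
torch.cuda.get_device_capability(), uses plain logging in place of TGI's
log_once helper, and the function name explain_fp8_kernel_choice is
hypothetical.

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def explain_fp8_kernel_choice() -> None:
        # Mirrors the branch added above; in the patch it only runs once the
        # Marlin FP8 path has already been selected (CUDA capability 8.x).
        major, minor = torch.cuda.get_device_capability()
        if major == 8 and minor == 9:
            # Ada GPUs (L4, L40) support FP8, but Marlin FP8 decodes faster.
            logger.info(
                "GPU supports FP8, but using Marlin FP8 kernel for better performance"
            )
        else:
            # Other 8.x GPUs (e.g. A100 at 8.0) have no native FP8 support.
            logger.info("GPU does not support FP8, using Marlin FP8 kernel")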