Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)
Clarify FP8-Marlin use on capability 8.9 (#2940)
The log message stated that the GPU does not support FP8 on capability 8.9. However, we use FP8-Marlin on that capability because it is faster.
This commit is contained in:
parent 1d3c9beba8
commit 1dd346666a
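For context before the diff: `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, and Ada Lovelace GPUs such as the L4 and L40 report `(8, 9)`, which does include native FP8 support. The sketch below illustrates the selection logic this commit clarifies; the function name and returned strings are illustrative only, not TGI's actual API:

```python
import torch


def choose_fp8_backend() -> str:
    # Hypothetical helper mirroring the dispatch in get_fp8_linear().
    major, minor = torch.cuda.get_device_capability()
    if major == 8 and minor == 9:
        # Ada (L4/L40) supports FP8 W8A8 GEMM, but the W8A16 Marlin
        # kernel gives better decoding throughput, so it is chosen
        # deliberately, not as a fallback.
        return "marlin-fp8"
    if major < 9:
        # Older GPUs lack FP8 GEMM entirely; Marlin is the only option.
        return "marlin-fp8"
    # Hopper (9.0) and newer: let native FP8 kernels handle it.
    return "native-fp8"
```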
server/text_generation_server/layers/fp8.py

```diff
@@ -52,6 +52,16 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
             # gives better decoding throughput on L4 and L40.
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
+            if major == 8 and minor == 9:
+                log_once(
+                    logger.info,
+                    "GPU supports FP8, but using Marlin FP8 kernel for better performance",
+                )
+            else:
+                log_once(
+                    logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
+                )
+
             return GPTQMarlinFP8Linear
 
     # On other systems let Torch decide if the hardware supports FP8.
```
server/text_generation_server/layers/marlin/fp8.py

```diff
@@ -2,14 +2,12 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from loguru import logger
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.layers.marlin.gptq import _check_valid_shape
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
     permute_scales,
 )
-from text_generation_server.utils.log import log_once
 
 try:
     import marlin_kernels
@@ -36,8 +34,6 @@ class GPTQMarlinFP8Linear(nn.Module):
         _check_marlin_kernels()
         assert marlin_kernels is not None
 
-        log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
-
         scales = scales.unsqueeze(0)
         if scales.shape[1] == 1:
             out_features, in_features = qweight.shape
```
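The removed call used `log_once` from `text_generation_server.utils.log`, which deduplicates messages so that constructing many linear layers logs only once; with this commit the logging happens in `get_fp8_linear` instead. A minimal sketch of how such a helper can work, assuming a cache-based implementation rather than quoting the actual TGI source:

```python
from functools import lru_cache


@lru_cache(maxsize=10)
def log_once(log, msg: str) -> None:
    # lru_cache memoizes on the (log, msg) arguments, so repeating a
    # call with the same logger method and message is a cache hit and
    # the body runs only the first time.
    log(msg)
```

With a helper like this, `log_once(logger.info, "GPU supports FP8, ...")` emits the message at most once per process, no matter how many layers are built.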