mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
Clarify FP8-Marlin use on capability 8.9 (#2940)
The log message stated that the GPU does not support FP8 on capability 8.9. However we use FP8-Marlin on that capability because it is faster.
This commit is contained in:
parent
1d3c9beba8
commit
1dd346666a
@ -52,6 +52,16 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
|
|||||||
# gives better decoding throughput on L4 and L40.
|
# gives better decoding throughput on L4 and L40.
|
||||||
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
|
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
|
||||||
|
|
||||||
|
if major == 8 and minor == 9:
|
||||||
|
log_once(
|
||||||
|
logger.info,
|
||||||
|
"GPU supports FP8, but using Marlin FP8 kernel for better performance",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_once(
|
||||||
|
logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
|
||||||
|
)
|
||||||
|
|
||||||
return GPTQMarlinFP8Linear
|
return GPTQMarlinFP8Linear
|
||||||
|
|
||||||
# On other systems let Torch decide if the hardware supports FP8.
|
# On other systems let Torch decide if the hardware supports FP8.
|
||||||
|
@ -2,14 +2,12 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from loguru import logger
|
|
||||||
from text_generation_server.layers.fp8 import fp8_quantize
|
from text_generation_server.layers.fp8 import fp8_quantize
|
||||||
from text_generation_server.layers.marlin.gptq import _check_valid_shape
|
from text_generation_server.layers.marlin.gptq import _check_valid_shape
|
||||||
from text_generation_server.layers.marlin.util import (
|
from text_generation_server.layers.marlin.util import (
|
||||||
_check_marlin_kernels,
|
_check_marlin_kernels,
|
||||||
permute_scales,
|
permute_scales,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.log import log_once
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import marlin_kernels
|
import marlin_kernels
|
||||||
@ -36,8 +34,6 @@ class GPTQMarlinFP8Linear(nn.Module):
|
|||||||
_check_marlin_kernels()
|
_check_marlin_kernels()
|
||||||
assert marlin_kernels is not None
|
assert marlin_kernels is not None
|
||||||
|
|
||||||
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
|
||||||
|
|
||||||
scales = scales.unsqueeze(0)
|
scales = scales.unsqueeze(0)
|
||||||
if scales.shape[1] == 1:
|
if scales.shape[1] == 1:
|
||||||
out_features, in_features = qweight.shape
|
out_features, in_features = qweight.shape
|
||||||
|
Loading…
Reference in New Issue
Block a user