From ab4d480d91b2173c7dda7c47122bb1250c9b95aa Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 13 Aug 2024 16:52:15 -0400 Subject: [PATCH] fix: repack for marlin when single scale is provided --- server/text_generation_server/layers/marlin/fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py index fe55a58a3..827e47df6 100644 --- a/server/text_generation_server/layers/marlin/fp8.py +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -39,7 +39,8 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") scales = scales.unsqueeze(0) - if scales.shape[1] == 1: + # repack weights for Marlin if a single scale is provided + if scales.size(0) == 1: out_features, in_features = qweight.shape scales = scales.repeat(1, out_features) qweight, scales = repack_fp8_for_marlin(qweight, scales)