fix: repack for marlin when single scale is provided

This commit is contained in:
drbh 2024-08-13 16:52:15 -04:00
parent 1cebccc72b
commit ab4d480d91

View File

@ -39,7 +39,8 @@ class GPTQMarlinFP8Linear(nn.Module):
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0) scales = scales.unsqueeze(0)
if scales.shape[1] == 1: # repack weights for Marlin if a single scale is provided
if scales.size(0) == 1:
out_features, in_features = qweight.shape out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features) scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales) qweight, scales = repack_fp8_for_marlin(qweight, scales)