fix missing g_idx and eventual overflow in triton kernel

This commit is contained in:
IlyasMoutawwakil 2024-02-01 13:30:43 +00:00 committed by Nicolas Patry
parent 3963074ceb
commit 3ceeb85842

View File

@ -182,7 +182,7 @@ try:
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
zeros = (zeros + 1) & maxq # add 1 and avoid overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@ -251,7 +251,17 @@ class QuantLinear(nn.Module):
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
if g_idx is not None:
self.register_buffer("g_idx", g_idx)
else:
self.register_buffer(
"g_idx",
torch.tensor(
[i // groupsize for i in range(qweight.shape[0] * 32 // bits)],
device=qweight.device,
dtype=torch.int32,
),
)
if bias is not None:
self.register_buffer("bias", bias)
else: