fix missing g_idx and eventual overflow in triton kernel

This commit is contained in:
IlyasMoutawwakil 2024-02-01 13:30:43 +00:00
parent 8acbcb31d5
commit 0b5b858779

View File

@ -182,7 +182,7 @@ try:
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1 zeros = (zeros + 1) & maxq # add 1 and avoid overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@ -251,7 +251,17 @@ class QuantLinear(nn.Module):
self.register_buffer("qweight", qweight) self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros) self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales) self.register_buffer("scales", scales)
if g_idx is not None:
self.register_buffer("g_idx", g_idx) self.register_buffer("g_idx", g_idx)
else:
self.register_buffer(
"g_idx",
torch.tensor(
[i // groupsize for i in range(qweight.shape[0] * 32 // bits)],
device=qweight.device,
dtype=torch.int32,
),
)
if bias is not None: if bias is not None:
self.register_buffer("bias", bias) self.register_buffer("bias", bias)
else: else: