mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fix missing g_idx and possible overflow in Triton kernel
This commit is contained in:
parent
8acbcb31d5
commit
0b5b858779
@ -182,7 +182,7 @@ try:
|
|||||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||||
|
|
||||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||||
zeros = zeros + 1
|
zeros = (zeros + 1) & maxq # add 1 and avoid overflow
|
||||||
|
|
||||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||||
@ -251,7 +251,17 @@ class QuantLinear(nn.Module):
|
|||||||
self.register_buffer("qweight", qweight)
|
self.register_buffer("qweight", qweight)
|
||||||
self.register_buffer("qzeros", qzeros)
|
self.register_buffer("qzeros", qzeros)
|
||||||
self.register_buffer("scales", scales)
|
self.register_buffer("scales", scales)
|
||||||
self.register_buffer("g_idx", g_idx)
|
if g_idx is not None:
|
||||||
|
self.register_buffer("g_idx", g_idx)
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"g_idx",
|
||||||
|
torch.tensor(
|
||||||
|
[i // groupsize for i in range(qweight.shape[0] * 32 // bits)],
|
||||||
|
device=qweight.device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
)
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
self.register_buffer("bias", bias)
|
self.register_buffer("bias", bias)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user