cleanup fastlinear

This commit is contained in:
fxmarty 2024-05-17 09:02:14 +00:00
parent 3ded96fb4c
commit 956ac30ab9

View File

@ -61,16 +61,17 @@ class FastLinearROCm(torch.nn.Module):
weight = self.weight
bias = self.bias
if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
batched = False
inp_shape = inp.shape
if inp.dim() == 3:
inp = inp.view(-1, inp.size(-1))
inp = inp.view(-1, inp_shape[-1])
batched = True
m, k = weight.shape[0], inp.shape[1]
m, k = weight.shape[0], inp_shape[1]
out = torch.empty(
inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
)
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
_custom_C.LLMM1(weight, inp, out, 8)
@ -78,8 +79,10 @@ class FastLinearROCm(torch.nn.Module):
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)
if batched:
out = out.view(inp.shape[0], inp.shape[1], weight.shape[0])
out.view(*inp_shape[:-1], out.shape[-1])
if bias is not None:
out = out + bias
return out