mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
cleanup fastlinear
This commit is contained in:
parent
3ded96fb4c
commit
956ac30ab9
@ -61,16 +61,17 @@ class FastLinearROCm(torch.nn.Module):
|
||||
weight = self.weight
|
||||
bias = self.bias
|
||||
|
||||
if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
|
||||
if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
|
||||
batched = False
|
||||
inp_shape = inp.shape
|
||||
|
||||
if inp.dim() == 3:
|
||||
inp = inp.view(-1, inp.size(-1))
|
||||
inp = inp.view(-1, inp_shape[-1])
|
||||
batched = True
|
||||
|
||||
m, k = weight.shape[0], inp.shape[1]
|
||||
m, k = weight.shape[0], inp_shape[1]
|
||||
out = torch.empty(
|
||||
inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
|
||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
|
||||
)
|
||||
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
|
||||
_custom_C.LLMM1(weight, inp, out, 8)
|
||||
@ -78,8 +79,10 @@ class FastLinearROCm(torch.nn.Module):
|
||||
_custom_C.LLMM1(weight, inp, out, 4)
|
||||
else:
|
||||
out = F.linear(inp, weight)
|
||||
|
||||
if batched:
|
||||
out = out.view(inp.shape[0], inp.shape[1], weight.shape[0])
|
||||
out.view(*inp_shape[:-1], out.shape[-1])
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out
|
||||
|
Loading…
Reference in New Issue
Block a user