Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
cleanup fastlinear

commit 956ac30ab9
parent 3ded96fb4c
@@ -61,16 +61,17 @@ class FastLinearROCm(torch.nn.Module):
         weight = self.weight
         bias = self.bias
 
-        if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
+        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
             batched = False
+            inp_shape = inp.shape
 
             if inp.dim() == 3:
-                inp = inp.view(-1, inp.size(-1))
+                inp = inp.view(-1, inp_shape[-1])
                 batched = True
 
-            m, k = weight.shape[0], inp.shape[1]
+            m, k = weight.shape[0], inp_shape[1]
             out = torch.empty(
-                inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
+                inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
             )
             if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
                 _custom_C.LLMM1(weight, inp, out, 8)
@@ -78,8 +79,10 @@ class FastLinearROCm(torch.nn.Module):
                 _custom_C.LLMM1(weight, inp, out, 4)
             else:
                 out = F.linear(inp, weight)
 
             if batched:
-                out = out.view(inp.shape[0], inp.shape[1], weight.shape[0])
+                out.view(*inp_shape[:-1], out.shape[-1])
 
             if bias is not None:
                 out = out + bias
             return out
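For context, the shape handling this cleanup introduces can be sketched as a small standalone function: capture inp.shape before a 3-D input is flattened for the matmul, then restore the leading dimensions afterwards. The sketch below is illustrative only: skinny_linear is a hypothetical helper, and plain F.linear stands in for the ROCm-only _custom_C.LLMM1 kernel so it runs anywhere. Note that Tensor.view is not in-place, so the sketch assigns its result back to out.

import torch
import torch.nn.functional as F

def skinny_linear(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Remember the original shape before any flattening, mirroring the
    # `inp_shape = inp.shape` added in this commit.
    inp_shape = inp.shape
    batched = False

    # A 3-D input (batch, seq, hidden) is flattened to 2-D for the matmul.
    if inp.dim() == 3:
        inp = inp.view(-1, inp_shape[-1])
        batched = True

    # Stand-in for the ROCm skinny-GEMM path (_custom_C.LLMM1); plain
    # F.linear is used here so the sketch runs anywhere.
    out = F.linear(inp, weight)

    # Restore the original leading dimensions. Tensor.view is not in-place,
    # so the result must be assigned back.
    if batched:
        out = out.view(*inp_shape[:-1], out.shape[-1])
    return out

# Usage: a decode-step shaped input (one token per sequence).
x = torch.randn(1, 1, 16)
w = torch.randn(32, 16)
print(skinny_linear(x, w).shape)  # torch.Size([1, 1, 32])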
|