From 956ac30ab9572b9f6cd5bee93c8a22d8fe2ccced Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 17 May 2024 09:02:14 +0000
Subject: [PATCH] cleanup fastlinear

---
 server/text_generation_server/layers/linear.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index 987b6a7b..5bd6aa95 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -61,16 +61,17 @@ class FastLinearROCm(torch.nn.Module):
         weight = self.weight
         bias = self.bias
 
-        if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
+        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
             batched = False
+            inp_shape = inp.shape
 
             if inp.dim() == 3:
-                inp = inp.view(-1, inp.size(-1))
+                inp = inp.view(-1, inp_shape[-1])
                 batched = True
 
-            m, k = weight.shape[0], inp.shape[1]
+            m, k = weight.shape[0], inp_shape[-1]
             out = torch.empty(
-                inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
+                inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
             )
             if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
                 _custom_C.LLMM1(weight, inp, out, 8)
@@ -78,8 +79,10 @@ class FastLinearROCm(torch.nn.Module):
                 _custom_C.LLMM1(weight, inp, out, 4)
             else:
                 out = F.linear(inp, weight)
+
             if batched:
-                out = out.view(inp.shape[0], inp.shape[1], weight.shape[0])
+                out = out.view(*inp_shape[:-1], out.shape[-1])
+
             if bias is not None:
                 out = out + bias
             return out
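
For reference, a minimal standalone sketch of the shape bookkeeping these
hunks implement. It is not TGI code: the tensor sizes are illustrative
assumptions (taken from one of the dispatch conditions above), and F.linear
stands in for the _custom_C.LLMM1 skinny-GEMM kernel.

    import torch

    # One-row input, as required by the fast path's guard.
    inp = torch.randn(1, 1, 3584)     # (batch, seq, hidden)
    weight = torch.randn(8192, 3584)  # (m, k) linear weight

    assert inp.numel() // inp.shape[-1] == 1  # total rows == 1

    inp_shape = inp.shape                  # captured before flattening
    x = inp.view(-1, inp_shape[-1])        # (1, 3584): 2D input for the kernel
    m, k = weight.shape[0], inp_shape[-1]  # k is the hidden dim, hence [-1]

    out = torch.nn.functional.linear(x, weight)  # stand-in for _custom_C.LLMM1

    # Tensor.view is not in-place: it returns a new tensor, so the result
    # must be assigned back or the reshape is silently dropped.
    out = out.view(*inp_shape[:-1], out.shape[-1])
    assert out.shape == (1, 1, 8192)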