diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index 987b6a7b..5bd6aa95 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -61,16 +61,17 @@ class FastLinearROCm(torch.nn.Module):
         weight = self.weight
         bias = self.bias
 
-        if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
+        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
             batched = False
+            inp_shape = inp.shape
 
             if inp.dim() == 3:
-                inp = inp.view(-1, inp.size(-1))
+                inp = inp.view(-1, inp_shape[-1])
                 batched = True
 
-            m, k = weight.shape[0], inp.shape[1]
+            m, k = weight.shape[0], inp_shape[-1]
             out = torch.empty(
-                inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
+                inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
             )
             if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
                 _custom_C.LLMM1(weight, inp, out, 8)
@@ -78,8 +79,10 @@ class FastLinearROCm(torch.nn.Module):
                 _custom_C.LLMM1(weight, inp, out, 4)
             else:
                 out = F.linear(inp, weight)
+
             if batched:
-                out = out.view(inp.shape[0], inp.shape[1], weight.shape[0])
+                out = out.view(*inp_shape[:-1], out.shape[-1])
+
             if bias is not None:
                 out = out + bias
             return out
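
For reviewers, a minimal standalone sketch of the shape bug this patch fixes (plain PyTorch; `hidden` and `out_features` are illustrative values, and a plain matmul stands in for `_custom_C.LLMM1`): the old code rebuilt the 3D output view from `inp.shape` *after* `inp` had been flattened to 2D, so it requested the wrong target shape; saving `inp_shape` up front avoids that. Note also that `Tensor.view` is not in-place, so its result must be assigned back.

```python
import torch

# Illustrative sizes, not taken from the patch.
hidden, out_features = 16, 32
inp = torch.randn(1, 1, hidden)  # 3D decode-style input: batch=1, seq_len=1
weight = torch.randn(out_features, hidden)

inp_shape = inp.shape              # saved before flattening, as in the patch
inp = inp.view(-1, inp_shape[-1])  # flattened to 2D for the matmul kernel
out = inp @ weight.t()             # stand-in for _custom_C.LLMM1 / F.linear

# Old code: rebuilt the view from the *flattened* input's shape.
# inp.shape is now (1, hidden), so this requests (1, hidden, out_features)
# instead of (1, 1, out_features) and raises a RuntimeError.
try:
    out.view(inp.shape[0], inp.shape[1], weight.shape[0])
except RuntimeError as e:
    print("old view fails:", e)

# Patched code: restore the saved leading dims. view() returns a new
# tensor rather than mutating in place, so assign the result back.
out = out.view(*inp_shape[:-1], out.shape[-1])
print(out.shape)  # torch.Size([1, 1, 32])
```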