diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 2b898831..96b654d0 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -292,8 +292,7 @@ else: ) out = torch.empty_like(q) - - return flash_attn_cuda.fwd( + flash_attn_cuda.fwd( q, k, v, @@ -309,4 +308,5 @@ else: False, 0, None, - )[0] + ) + return out