diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py
index ac07aa98..f3517c47 100644
--- a/server/text_generation_server/models/flash_neox_modeling.py
+++ b/server/text_generation_server/models/flash_neox_modeling.py
@@ -319,12 +319,12 @@ class FlashNeoxAttention(torch.nn.Module):
             layer_past[...] = qkv_rot[:, 1:]
 
             # output
-            attn_output = torch.empty_like(qkv[:, 0])
+            attn_output = torch.empty_like(qkv_rot[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                qkv_rot[:, 0],
+                qkv_rot[:, 1],
+                qkv_rot[:, 2],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
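
Note on the change above (not part of the patch): flash_attn_cuda.fwd should consume the rotary-embedded tensors; if the rotation is applied out of place, feeding it the pre-rotation qkv means the kernel attends without any positional signal. Below is a minimal sketch of that effect, assuming an out-of-place toy apply_rotary helper; the helper, shapes, and values are illustrative assumptions, not the repository's rotary implementation.

import torch

def apply_rotary(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Toy rotary embedding: rotate channel pairs by a position-dependent angle.
    half = x.shape[-1] // 2
    theta = positions[:, None].float()          # stand-in for inv_freq * position
    cos, sin = torch.cos(theta), torch.sin(theta)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

seq_len, head_size = 4, 8
positions = torch.arange(seq_len)
qkv = torch.randn(seq_len, 3, head_size)        # packed [seq, (q|k|v), head_size]

# Rotation is applied out of place in this sketch, so qkv itself is untouched.
qkv_rot = qkv.clone()
qkv_rot[:, 0] = apply_rotary(qkv[:, 0], positions)
qkv_rot[:, 1] = apply_rotary(qkv[:, 1], positions)

# Attention scores from the pre-rotation tensors carry no positional signal ...
scores_unrotated = qkv[:, 0] @ qkv[:, 1].T
# ... while the rotated tensors encode relative position in every dot product.
scores_rotated = qkv_rot[:, 0] @ qkv_rot[:, 1].T
print(torch.allclose(scores_unrotated, scores_rotated))  # False: the rotation matters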