From 08b7e4a28232c71b97a5e1fc90d50be9a8e6c6f1 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 30 Mar 2023 16:12:23 +0200 Subject: [PATCH] fix(server): fix flash neox rotary embeddings (#150) --- .../text_generation_server/models/flash_neox_modeling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py index ac07aa98..f3517c47 100644 --- a/server/text_generation_server/models/flash_neox_modeling.py +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -319,12 +319,12 @@ class FlashNeoxAttention(torch.nn.Module): layer_past[...] = qkv_rot[:, 1:] # output - attn_output = torch.empty_like(qkv[:, 0]) + attn_output = torch.empty_like(qkv_rot[:, 0]) # flash attention flash_attn_cuda.fwd( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], + qkv_rot[:, 0], + qkv_rot[:, 1], + qkv_rot[:, 2], attn_output, cu_seqlens, cu_seqlens,