mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
fix(server): fix flash neox rotary embeddings (#150)
This commit is contained in:
parent
610bb1f978
commit
08b7e4a282
@@ -319,12 +319,12 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
layer_past[...] = qkv_rot[:, 1:]
|
layer_past[...] = qkv_rot[:, 1:]
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
attn_output = torch.empty_like(qkv_rot[:, 0])
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
qkv[:, 0],
|
qkv_rot[:, 0],
|
||||||
qkv[:, 1],
|
qkv_rot[:, 1],
|
||||||
qkv[:, 2],
|
qkv_rot[:, 2],
|
||||||
attn_output,
|
attn_output,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
|
Loading…
Reference in New Issue
Block a user