fix: return the out tensor rather then the functions return value (#2361)

This commit is contained in:
drbh 2024-08-06 07:49:53 -04:00 committed by yuanwu
parent 8b0f5feb02
commit 83d1f23fea

View File

@ -292,8 +292,7 @@ else:
)
out = torch.empty_like(q)
return flash_attn_cuda.fwd(
flash_attn_cuda.fwd(
q,
k,
v,
@ -309,4 +308,5 @@ else:
False,
0,
None,
)[0]
)
return out