fix: return the out tensor rather then the functions return value (#2361)

This commit is contained in:
drbh 2024-08-06 07:49:53 -04:00 committed by GitHub
parent dd47a3dac4
commit 29b8d19cdf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -292,8 +292,7 @@ else:
) )
out = torch.empty_like(q) out = torch.empty_like(q)
flash_attn_cuda.fwd(
return flash_attn_cuda.fwd(
q, q,
k, k,
v, v,
@ -309,4 +308,5 @@ else:
False, False,
0, 0,
None, None,
)[0] )
return out