This commit is contained in:
fxmarty 2024-04-19 11:23:27 +00:00
parent 0ca83be883
commit f723e5ccb5

View File

@ -194,16 +194,11 @@ def attention(
None,
)
elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
logger.info(f"cu_seqlens {cu_seqlens}")
logger.info(f"max_s {max_s}")
output, _ = triton_attention(
q,
k,
v,
None,
out,
cu_seqlens,
cu_seqlens,
max_s,
@ -211,8 +206,6 @@ def attention(
True,
softmax_scale,
)
logger.info(f"output shape {output.shape} {output.dtype}")
logger.info(f"output {output}")
return output
else:
raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")