commit f723e5ccb5
parent 0ca83be883
Author: fxmarty
Date: 2024-04-19 11:23:27 +00:00

@@ -194,16 +194,11 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-        logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
-        logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
-        logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
-        logger.info(f"cu_seqlens {cu_seqlens}")
-        logger.info(f"max_s {max_s}")
         output, _ = triton_attention(
             q,
             k,
             v,
-            None,
+            out,
             cu_seqlens,
             cu_seqlens,
             max_s,
@@ -211,8 +206,6 @@ def attention(
             True,
             softmax_scale,
         )
-        logger.info(f"output shape {output.shape} {output.dtype}")
-        logger.info(f"output {output}")
         return output
     else:
         raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
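
The change is small: the temporary logger.info debugging around the ROCm Triton flash-attention path is removed, and triton_attention now receives the preallocated out tensor instead of None, so the kernel writes into the caller's buffer rather than allocating a fresh output. Below is a minimal sketch of that calling convention, not the repository's kernel: my_varlen_attention and its eager reference loop are hypothetical stand-ins, and the real call passes cu_seqlens twice (once for queries, once for keys/values), which the sketch collapses since both sides share the same lengths.

# A minimal sketch of calling a varlen attention kernel with a preallocated
# output buffer. `my_varlen_attention` is a hypothetical stand-in for the
# Triton kernel in the diff, implemented here as a plain eager-mode loop.
import torch

def my_varlen_attention(q, k, v, out, cu_seqlens, max_s, causal, softmax_scale):
    # Varlen layout: sequences are packed along dim 0 as
    # [total_tokens, num_heads, head_dim]; cu_seqlens holds the cumulative
    # sequence boundaries. `max_s` is unused by this reference loop; real
    # kernels typically use it to size the launch grid.
    for i in range(cu_seqlens.numel() - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        qi = q[start:end].transpose(0, 1)  # [num_heads, seqlen, head_dim]
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        scores = torch.matmul(qi, ki.transpose(-2, -1)) * softmax_scale
        if causal:
            seqlen = end - start
            mask = torch.triu(
                torch.ones(seqlen, seqlen, dtype=torch.bool, device=q.device),
                diagonal=1,
            )
            scores = scores.masked_fill(mask, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        # Write into the caller's buffer: this is the point of the None -> out
        # change above, which avoids allocating a second output tensor.
        out[start:end] = torch.matmul(attn, vi).transpose(0, 1)
    # Return a tuple so `output, _ = ...` unpacking works; the second value is
    # a placeholder for whatever extra tensor the real kernel returns.
    return out, None

# Usage: the caller allocates `out` once and hands it to the kernel.
total_tokens, num_heads, head_dim = 7, 2, 8
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)  # two packed sequences
output, _ = my_varlen_attention(
    q, k, v, out, cu_seqlens, max_s=4, causal=True,
    softmax_scale=head_dim ** -0.5,
)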