Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 16:02:10 +00:00)
working
commit f723e5ccb5 (parent 0ca83be883)
@@ -194,16 +194,11 @@ def attention(
            None,
        )
    elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
        logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
        logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
        logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
        logger.info(f"cu_seqlens {cu_seqlens}")
        logger.info(f"max_s {max_s}")
        output, _ = triton_attention(
            q,
            k,
            v,
            None,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
@@ -211,8 +206,6 @@ def attention(
            True,
            softmax_scale,
        )
        logger.info(f"output shape {output.shape} {output.dtype}")
        logger.info(f"output {output}")
        return output
    else:
        raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
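The lines touched in this hunk follow a common kernel-debugging pattern: log the shape, dtype, and contiguity of every input right before the fused attention call, and the output metadata right after it. Below is a minimal, self-contained sketch of that pattern, not the TGI module itself; `loguru` stands in for the server's logger, `log_varlen_attention_inputs` is a made-up helper name, and the tensor sizes are invented for illustration.

import torch
from loguru import logger


def log_varlen_attention_inputs(q, k, v, cu_seqlens, max_s):
    # Same metadata the diff logs before calling triton_attention:
    # shape, dtype and contiguity of q/k/v, plus the ragged-batch bookkeeping.
    logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
    logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
    logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
    logger.info(f"cu_seqlens {cu_seqlens}")
    logger.info(f"max_s {max_s}")


if __name__ == "__main__":
    # Hypothetical packed batch: two sequences of length 3 and 5,
    # 8 heads, head dimension 64 (illustrative numbers only).
    total_tokens, num_heads, head_dim = 8, 8, 64
    q = torch.randn(total_tokens, num_heads, head_dim)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
    max_s = 5
    log_varlen_attention_inputs(q, k, v, cu_seqlens, max_s)

Logging `is_contiguous()` alongside the shapes is the useful detail here: fused kernels often assume particular layouts, so non-contiguous inputs are a natural first thing to rule out when the Triton path misbehaves.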
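The surrounding control flow is the other notable part of the hunk: `attention` branches on platform and capability flags and, when no backend matches, raises `NotImplementedError` with every flag's value baked into the message. The sketch below shows only that dispatch shape, with hard-coded placeholder flag values and a made-up `attention_dispatch` wrapper instead of TGI's real import-time detection and kernel calls.

# Placeholder flag values; in the real module they come from platform and
# package detection at import time.
IS_CUDA_SYSTEM = False
IS_ROCM_SYSTEM = True
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = True
ROCM_USE_FLASH_ATTN_V2_TRITON = True


def attention_dispatch(run_cuda_kernel, run_triton_kernel):
    # Pick a kernel based on the flags, mirroring the elif/else chain in the diff.
    if IS_CUDA_SYSTEM and HAS_FLASH_ATTN_V2_CUDA:
        return run_cuda_kernel()
    elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
        return run_triton_kernel()
    else:
        # Failing loudly with every flag's value makes misconfigured builds
        # easy to diagnose from a single log line.
        raise NotImplementedError(
            f"Flash attention is not installed "
            f"(IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, "
            f"HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, "
            f"HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})"
        )


if __name__ == "__main__":
    # With the placeholder flags above, this takes the ROCm/Triton branch.
    print(attention_dispatch(lambda: "cuda kernel", lambda: "triton kernel"))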