mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 16:02:10 +00:00)

commit f723e5ccb5
parent 0ca83be883

    working
@@ -194,16 +194,11 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-        logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
-        logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
-        logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
-        logger.info(f"cu_seqlens {cu_seqlens}")
-        logger.info(f"max_s {max_s}")
         output, _ = triton_attention(
             q,
             k,
             v,
-            None,
+            out,
             cu_seqlens,
             cu_seqlens,
             max_s,
@@ -211,8 +206,6 @@ def attention(
             True,
             softmax_scale,
         )
-        logger.info(f"output shape {output.shape} {output.dtype}")
-        logger.info(f"output {output}")
         return output
     else:
         raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
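
For context, a minimal Python sketch of the call site after this change. Assumptions for illustration, not confirmed by the diff: the import path of triton_attention, the second max_s argument (it falls in the unshown gap between the two hunks), and that out is preallocated by the caller with the same shape as q.

    # Assumed import path for the repository's ROCm Triton flash-attention
    # wrapper; the exact module name may differ.
    from text_generation_server.utils.flash_attn_triton import triton_attention

    def rocm_varlen_attention(q, k, v, out, cu_seqlens, max_s, softmax_scale):
        # q, k, v: packed [total_tokens, num_heads, head_dim] tensors.
        # out: assumed preallocated with the same shape as q.
        output, _ = triton_attention(
            q,
            k,
            v,
            out,            # was None before this commit
            cu_seqlens,     # query sequence boundaries of the packed batch
            cu_seqlens,     # key boundaries (same tensor: self-attention)
            max_s,          # max query sequence length
            max_s,          # max key length (assumed; hidden between hunks)
            True,           # causal masking
            softmax_scale,
        )
        return output

The visible effect of the commit is twofold: the per-call logger.info shape dumps are removed, and the caller-supplied out buffer is passed instead of None (which presumably left the kernel wrapper to allocate a fresh output tensor on every call).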