diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 71577306..8b0b72c3 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -194,16 +194,11 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-        logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
-        logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
-        logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
-        logger.info(f"cu_seqlens {cu_seqlens}")
-        logger.info(f"max_s {max_s}")
         output, _ = triton_attention(
             q,
             k,
             v,
-            None,
+            out,
             cu_seqlens,
             cu_seqlens,
             max_s,
@@ -211,8 +206,6 @@ def attention(
             True,
             softmax_scale,
         )
-        logger.info(f"output shape {output.shape} {output.dtype}")
-        logger.info(f"output {output}")
         return output
     else:
         raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
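
For reference, a minimal sketch of how this Triton path could be exercised after the change, with the caller supplying the preallocated out buffer that is now forwarded to triton_attention instead of None. The packed q/k/v shapes, sequence layout, and the exact signature of the attention() wrapper are assumptions for illustration, not taken from this diff; it also presupposes a ROCm system with ROCM_USE_FLASH_ATTN_V2_TRITON enabled.

# Hypothetical call site; tensor names, shapes, and dtype are illustrative
# assumptions, not part of this diff.
import torch

from text_generation_server.utils.flash_attn import attention

total_tokens, num_heads, head_dim = 512, 32, 128
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Preallocated output buffer: with this change the Triton kernel writes into
# `out` rather than receiving None for the output argument.
out = torch.empty_like(q)

# Two packed sequences of 256 tokens each, described by cumulative lengths.
cu_seqlens = torch.tensor([0, 256, 512], dtype=torch.int32, device="cuda")
max_s = 256
softmax_scale = head_dim ** -0.5

attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)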