This commit is contained in:
fxmarty 2024-04-19 11:57:16 +00:00
parent 1b4c8b4b3e
commit ec5343ec5e
2 changed files with 5 additions and 4 deletions

View File

@ -28,10 +28,10 @@ ROCM_USE_FLASH_ATTN_V2_TRITON = False
if IS_ROCM_SYSTEM: if IS_ROCM_SYSTEM:
if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true": if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
ROCM_USE_FLASH_ATTN_V2_TRITON = True ROCM_USE_FLASH_ATTN_V2_TRITON = True
logger.info("ROCm: using Flash Attention 2 Triton implementaion.") logger.info("ROCm: using Flash Attention 2 Triton implementation.")
else: else:
ROCM_USE_FLASH_ATTN_V2_CK = True ROCM_USE_FLASH_ATTN_V2_CK = True
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementaion.") logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
try: try:
try: try:

View File

@ -1291,5 +1291,6 @@ try:
freqs = torch.outer(t, self.inv_freq.to(device=t.device)) freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
except ImportError as e:
raise e except ImportError:
pass