mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
cleaning
This commit is contained in:
parent
1b4c8b4b3e
commit
ec5343ec5e
@ -28,10 +28,10 @@ ROCM_USE_FLASH_ATTN_V2_TRITON = False
|
|||||||
if IS_ROCM_SYSTEM:
|
if IS_ROCM_SYSTEM:
|
||||||
if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
|
if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
|
||||||
ROCM_USE_FLASH_ATTN_V2_TRITON = True
|
ROCM_USE_FLASH_ATTN_V2_TRITON = True
|
||||||
logger.info("ROCm: using Flash Attention 2 Triton implementaion.")
|
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
|
||||||
else:
|
else:
|
||||||
ROCM_USE_FLASH_ATTN_V2_CK = True
|
ROCM_USE_FLASH_ATTN_V2_CK = True
|
||||||
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementaion.")
|
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
|
@ -1291,5 +1291,6 @@ try:
|
|||||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
|
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
|
||||||
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
||||||
except ImportError as e:
|
|
||||||
raise e
|
except ImportError:
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user