diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 8b0b72c3..791d705c 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -28,10 +28,10 @@ ROCM_USE_FLASH_ATTN_V2_TRITON = False
 if IS_ROCM_SYSTEM:
     if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
         ROCM_USE_FLASH_ATTN_V2_TRITON = True
-        logger.info("ROCm: using Flash Attention 2 Triton implementaion.")
+        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
     else:
         ROCM_USE_FLASH_ATTN_V2_CK = True
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementaion.")
+        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
 
 try:
     try:
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 2f1a2b64..8e36f654 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1291,5 +1291,6 @@ try:
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
-except ImportError as e:
-    raise e
+
+except ImportError:
+    pass