diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 94518b8f..a6d0204f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -33,7 +33,12 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -788,14 +793,16 @@ class FlashCausalLM(Model):
 
         if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
             total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+            total_gpu_memory = torch.cuda.get_device_properties(
+                self.device
+            ).total_memory
 
             free_memory = max(
                 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
             )
         elif IS_XPU_SYSTEM:
             total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
-            free_memory = int(total_gpu_memory *0.5)
+            free_memory = int(total_gpu_memory * 0.5)
         else:
             raise NotImplementedError("FlashModel is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index f37fc542..609a188d 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -20,6 +20,7 @@ tracer = trace.get_tracer(__name__)
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
+
 class FlashLlama(FlashCausalLM):
     def __init__(
         self,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 70c978de..f82e27db 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 6eb25f22..ccf38a0c 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 6147398a..e66f1bf8 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
+
 tracer = trace.get_tracer(__name__)
 
 
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index 7c0d8001..db205f4d 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -1,5 +1,6 @@
 import torch
 
+
 def is_xpu_available():
     try:
         import intel_extension_for_pytorch
@@ -8,6 +9,7 @@ def is_xpu_available():
 
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
 IS_XPU_SYSTEM = is_xpu_available()
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 638cb0a0..8c46ea49 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -8,6 +8,7 @@ from typing import List, Tuple, Optional
 from loguru import logger
 from functools import lru_cache
 
+# Dummy comment.
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
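
For context on the `flash_causal_lm.py` hunk above, the sketch below restates the memory-budget logic it touches, assuming a CUDA/ROCm or XPU build of PyTorch. The standalone function name `available_memory` and the example `MEMORY_FRACTION` value are illustrative only; in the repository this logic lives in `FlashCausalLM.warmup` and `MEMORY_FRACTION` comes from `text_generation_server.utils.dist`.

```python
import torch

MEMORY_FRACTION = 0.9  # illustrative value; the real one is read from utils.dist


def available_memory(device: torch.device) -> int:
    """Rough sketch of the budget computed during warmup in the diff above."""
    if device.type == "cuda":  # covers both CUDA and ROCm builds of PyTorch
        total_free_memory, _ = torch.cuda.mem_get_info(device)
        total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
        # Keep MEMORY_FRACTION of the card for the model; never go negative.
        return max(
            0, int(total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
        )
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # On XPU the diff simply budgets half of the device memory.
        total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
        return int(total_gpu_memory * 0.5)
    raise NotImplementedError("FlashModel is only available on GPU")
```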