From 5373fc4707c0dbbb42d94e2994fa70fcfd628120 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Sat, 27 Apr 2024 11:33:24 +0200
Subject: [PATCH] Fix tests.

---
 docs/source/basic_tutorials/launcher.md              |  2 +-
 .../models/flash_causal_lm.py                        | 13 ++++++++++---
 server/text_generation_server/models/flash_llama.py |  1 +
 server/text_generation_server/models/flash_neox.py  |  1 +
 server/text_generation_server/models/flash_rw.py    |  1 +
 .../models/flash_santacoder.py                       |  1 +
 server/text_generation_server/utils/import_utils.py |  2 ++
 7 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index de7c995d..1e5b6fd2 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -162,7 +162,7 @@ Options:
           This setting is only applied if there is room in the batch as defined by `max_batch_total_tokens`.
 
           [env: WAITING_SERVED_RATIO=]
-          [default: 1.2]
+          [default: 0.3]
 
 ```
 ## MAX_BATCH_PREFILL_TOKENS
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 94518b8f..a6d0204f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -33,7 +33,12 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -788,14 +793,16 @@ class FlashCausalLM(Model):
 
         if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
             total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+            total_gpu_memory = torch.cuda.get_device_properties(
+                self.device
+            ).total_memory
 
             free_memory = max(
                 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
             )
         elif IS_XPU_SYSTEM:
             total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
-            free_memory = int(total_gpu_memory *0.5)
+            free_memory = int(total_gpu_memory * 0.5)
         else:
             raise NotImplementedError("FlashModel is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index f37fc542..609a188d 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -20,6 +20,7 @@ tracer = trace.get_tracer(__name__)
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
+
 class FlashLlama(FlashCausalLM):
     def __init__(
         self,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 70c978de..f82e27db 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 6eb25f22..ccf38a0c 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 6147398a..e66f1bf8 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
 )
 
 from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index 7c0d8001..db205f4d 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -1,5 +1,6 @@
 import torch
 
+
 def is_xpu_available():
     try:
         import intel_extension_for_pytorch
@@ -8,6 +9,7 @@ def is_xpu_available():
 
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
 IS_XPU_SYSTEM = is_xpu_available()
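
Aside (not part of the patch): a minimal sketch of the free-memory computation that the flash_causal_lm.py hunk re-wraps above, for readers skimming the diff. The MEMORY_FRACTION value and the function name free_memory_for_cache are stand-ins here; in the server the fraction is imported from text_generation_server.utils.dist.

```python
import torch

# Stand-in for text_generation_server.utils.dist.MEMORY_FRACTION (assumed value).
MEMORY_FRACTION = 0.9


def free_memory_for_cache(device: torch.device) -> int:
    """Rough equivalent of the CUDA/ROCm branch touched in flash_causal_lm.py."""
    total_free_memory, _ = torch.cuda.mem_get_info(device)
    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
    # Keep (1 - MEMORY_FRACTION) of the card in reserve and never report a negative amount.
    return max(0, int(total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory))
```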