diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index b693258c..d2363a32 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -18,15 +18,19 @@ def get_cuda_free_memory(device, memory_fraction):
 
 
 def get_xpu_free_memory(device, memory_fraction):
-    total_memory = torch.xpu.get_device_properties(device).total_memory
-    device_id = device.index
+    """Return the amount of XPU memory available for new allocations.
+
+    The budget is the free memory reported by ``torch.xpu.mem_get_info`` minus
+    the share of total device memory held back via ``XPU_MEMORY_FRACTION``.
+    """
+    total_free_memory, total_xpu_memory = torch.xpu.mem_get_info(device)
+    # NOTE: the ``memory_fraction`` argument is overridden by the environment.
     memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
-    free_memory = max(
-        0,
-        int(
-            total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
-        ),
-    )
+    # The fraction arithmetic yields a float; truncate so the function keeps
+    # returning an integer, as it did before this change.
+    free_memory = max(
+        0, int(total_free_memory - (1 - memory_fraction) * total_xpu_memory)
+    )
     return free_memory
 
 