update get xpu memory api

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-03-12 00:37:17 -07:00
parent cdc5380f2b
commit 468da545af

View File

@ -18,15 +18,9 @@ def get_cuda_free_memory(device, memory_fraction):
def get_xpu_free_memory(device, memory_fraction):
total_memory = torch.xpu.get_device_properties(device).total_memory
device_id = device.index
total_free_memory, total_xpu_memory = torch.xpu.mem_get_info(device)
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
free_memory = max(
0,
int(
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
),
)
free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_xpu_memory)
return free_memory