diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 57bea16b..c3929392 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -21,7 +21,7 @@ def get_cuda_free_memory(device, memory_fraction): def get_xpu_free_memory(device, memory_fraction): total_memory = torch.xpu.get_device_properties(device).total_memory - device_id = str(device)[4] + device_id = device.index query = f"xpu-smi dump -d {device_id} -m 18 -n 1" output = subprocess.check_output(query.split()).decode("utf-8").split("\n") used_memory = float(output[1].split(",")[-1]) * 1024 * 1024