diff --git a/backends/gaudi/server/text_generation_server/utils/debug.py b/backends/gaudi/server/text_generation_server/utils/debug.py index 8bbcad6a..690da54e 100644 --- a/backends/gaudi/server/text_generation_server/utils/debug.py +++ b/backends/gaudi/server/text_generation_server/utils/debug.py @@ -4,8 +4,8 @@ import os import glob import time -from optimum.habana.utils import to_gb_rounded import habana_frameworks.torch as htorch +import numpy as np START_TS = None DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME") @@ -14,6 +14,19 @@ if "GRAPH_VISUALIZATION" in os.environ: os.remove(f) +def to_gb_rounded(mem: float) -> float: + """ + Rounds and converts to GB. + + Args: + mem (float): memory in bytes + + Returns: + float: memory in GB rounded to the second decimal + """ + return np.round(mem / 1024**3, 2) + + def count_hpu_graphs(): return len(glob.glob(".graph_dumps/*PreGraph*"))