diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py
index 515c2c9d..3dab7695 100644
--- a/server/text_generation_server/interceptor.py
+++ b/server/text_generation_server/interceptor.py
@@ -29,6 +29,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
+            from .utils.debug import dbg_trace
+            dbg_trace('EXCEPTION', traceback.format_exc())
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 5cc35165..17364e82 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1,8 +1,6 @@
 import os
 import tempfile
 import itertools
-import time
-import glob
 
 import bisect
 import math
@@ -17,7 +15,7 @@ import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from contextlib import nullcontext
-from optimum.habana.utils import HabanaProfile, to_gb_rounded
+from optimum.habana.utils import HabanaProfile
 
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -37,42 +35,20 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
+from text_generation_server.utils.debug import dbg_trace
 from loguru import logger
 from functools import wraps
 
 
 tracer = trace.get_tracer(__name__)
 
-if 'GRAPH_VISUALIZATION' in os.environ:
-    for f in glob.glob('.graph_dumps/*'):
-        os.remove(f)
-
 MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
-DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
-START_TS = None
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 
 
-def count_hpu_graphs():
-    return len(glob.glob('.graph_dumps/*PreGraph*'))
-
-
-def dbg_trace(tag, txt):
-    global START_TS
-    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
-        if START_TS is None:
-            START_TS = time.perf_counter()
-        time_offset = time.perf_counter() - START_TS
-        mem_stats = htorch.hpu.memory.memory_stats()
-        mem_used = to_gb_rounded(mem_stats['InUse'])
-        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
-        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
-              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
-
-
 def round_up(number, k):
     return (number + k - 1) // k * k
 
diff --git a/server/text_generation_server/utils/debug.py b/server/text_generation_server/utils/debug.py
new file mode 100644
index 00000000..1ecaca9d
--- /dev/null
+++ b/server/text_generation_server/utils/debug.py
@@ -0,0 +1,29 @@
+import os
+import glob
+import time
+
+from optimum.habana.utils import to_gb_rounded
+import habana_frameworks.torch as htorch
+
+START_TS = None
+DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
+if 'GRAPH_VISUALIZATION' in os.environ:
+    for f in glob.glob('.graph_dumps/*'):
+        os.remove(f)
+
+
+def count_hpu_graphs():
+    return len(glob.glob('.graph_dumps/*PreGraph*'))
+
+
+def dbg_trace(tag, txt):
+    global START_TS
+    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
+        if START_TS is None:
+            START_TS = time.perf_counter()
+        time_offset = time.perf_counter() - START_TS
+        mem_stats = htorch.hpu.memory.memory_stats()
+        mem_used = to_gb_rounded(mem_stats['InUse'])
+        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
+        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
+              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
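
For context, a minimal sketch of how the relocated helper is exercised after this change. The environment variable names and the trace-line format come from the diff above; the `'STEP'` tag, the message string, and the `/tmp/tgi_dbg.log` path are illustrative assumptions, not values used anywhere in this PR:

```python
# Sketch: enabling and calling the relocated dbg_trace helper.
# DBG_TRACE_FILENAME must be set before importing the module, since
# debug.py reads it once at import time. Running this requires a
# Habana/HPU runtime, because dbg_trace queries htorch memory stats.
import os
os.environ['DBG_TRACE_FILENAME'] = '/tmp/tgi_dbg.log'  # hypothetical path

from text_generation_server.utils.debug import dbg_trace

dbg_trace('STEP', 'bs:4 seq_len:128')  # 'STEP' is a made-up tag
# Appends a line of the form:
# ts:0.000s g:0 mu:1.2GB mmu:1.3GB | STEP | bs:4 seq_len:128
```

Note that `dbg_trace` only writes on rank 0 (it checks the `RANK` environment variable, defaulting to 0) and opens the trace file in append mode on every call, so traces from successive server restarts accumulate in the same file.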