Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-26 12:32:10 +00:00

Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>

parent: c7ccfb87ff
commit: 212136dff8
@@ -29,6 +29,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
+            from .utils.debug import dbg_trace
+            dbg_trace('EXCEPTION', traceback.format_exc())
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
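The two added lines record the full traceback through the new debug tracer just before the RPC is aborted with an INTERNAL status. For context, here is a minimal sketch of the handler this hunk sits in, assuming the grpc-interceptor `AsyncServerInterceptor` API; the method body and the `trace` variable are not shown in the diff and are reconstructed here as assumptions:

```python
# Sketch only: everything outside the hunk is assumed, not taken from the diff.
import traceback

import torch
from google.rpc import code_pb2, status_pb2
from grpc_interceptor.server import AsyncServerInterceptor
from grpc_status import rpc_status


class ExceptionInterceptor(AsyncServerInterceptor):
    async def intercept(self, method, request_or_iterator, context, method_name):
        try:
            response = method(request_or_iterator, context)
            return await response
        except Exception as err:
            trace = "\n" + traceback.format_exc()  # assumed; the hunk only uses `trace`

            # Release cached GPU memory so an OOM-ed worker can keep serving.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # New in this commit: log the traceback to the debug trace file.
            from .utils.debug import dbg_trace
            dbg_trace('EXCEPTION', traceback.format_exc())
            await context.abort_with_status(
                rpc_status.to_status(
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                )
            )
```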
@@ -1,8 +1,6 @@
 import os
 import tempfile
 import itertools
-import time
-import glob
 import bisect
 import math
 
@@ -17,7 +15,7 @@ import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from contextlib import nullcontext
-from optimum.habana.utils import HabanaProfile, to_gb_rounded
+from optimum.habana.utils import HabanaProfile
 
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -37,42 +35,20 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
+from text_generation_server.utils.debug import dbg_trace
 from loguru import logger
 from functools import wraps
 
 
 tracer = trace.get_tracer(__name__)
 
-if 'GRAPH_VISUALIZATION' in os.environ:
-    for f in glob.glob('.graph_dumps/*'):
-        os.remove(f)
-
 MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
-DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
-START_TS = None
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 
 
-def count_hpu_graphs():
-    return len(glob.glob('.graph_dumps/*PreGraph*'))
-
-
-def dbg_trace(tag, txt):
-    global START_TS
-    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
-        if START_TS is None:
-            START_TS = time.perf_counter()
-        time_offset = time.perf_counter() - START_TS
-        mem_stats = htorch.hpu.memory.memory_stats()
-        mem_used = to_gb_rounded(mem_stats['InUse'])
-        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
-        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
-              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
-
-
 def round_up(number, k):
     return (number + k - 1) // k * k
server/text_generation_server/utils/debug.py — new file (29 lines)

@@ -0,0 +1,29 @@
+import os
+import glob
+import time
+
+from optimum.habana.utils import to_gb_rounded
+import habana_frameworks.torch as htorch
+
+START_TS = None
+DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
+if 'GRAPH_VISUALIZATION' in os.environ:
+    for f in glob.glob('.graph_dumps/*'):
+        os.remove(f)
+
+
+def count_hpu_graphs():
+    return len(glob.glob('.graph_dumps/*PreGraph*'))
+
+
+def dbg_trace(tag, txt):
+    global START_TS
+    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
+        if START_TS is None:
+            START_TS = time.perf_counter()
+        time_offset = time.perf_counter() - START_TS
+        mem_stats = htorch.hpu.memory.memory_stats()
+        mem_used = to_gb_rounded(mem_stats['InUse'])
+        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
+        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
+              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
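A hypothetical usage sketch for the new module (it needs a Habana environment with `habana_frameworks` and `optimum-habana` installed): `dbg_trace` is a no-op unless `DBG_TRACE_FILENAME` is set and the process is rank 0, so enabling it only requires exporting the variable before the server starts. The path, tag, and text below are made up:

```python
import os

# Point the trace at a file before importing the module; unset means no-op.
os.environ["DBG_TRACE_FILENAME"] = "/tmp/tgi_dbg.log"  # hypothetical path

from text_generation_server.utils.debug import dbg_trace

dbg_trace("WARMUP", "batch_size=4 seq_len=256")  # made-up tag and text
# Appends a line of the form:
#   ts:0.000s g:0 mu:1.2GB mmu:1.2GB | WARMUP | batch_size=4 seq_len=256
```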