mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 07:52:06 +00:00
Debugging utils (#14)
This commit is contained in:
parent
a8c5b69e2c
commit
41c4f4fa41
@ -6,6 +6,8 @@ from grpc_status import rpc_status
|
||||
from grpc_interceptor.server import AsyncServerInterceptor
|
||||
from loguru import logger
|
||||
from typing import Callable, Any
|
||||
import traceback
|
||||
import os
|
||||
|
||||
|
||||
class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
@ -20,6 +22,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
response = method(request_or_iterator, context)
|
||||
return await response
|
||||
except Exception as err:
|
||||
trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
|
||||
method_name = method_name.split("/")[-1]
|
||||
logger.exception(f"Method {method_name} encountered an error.")
|
||||
|
||||
@ -28,6 +31,6 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
|
||||
await context.abort_with_status(
|
||||
rpc_status.to_status(
|
||||
status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
|
||||
status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
|
||||
)
|
||||
)
|
||||
|
@ -1,6 +1,8 @@
|
||||
import os
|
||||
import tempfile
|
||||
import itertools
|
||||
import time
|
||||
import glob
|
||||
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
import torch
|
||||
@ -12,7 +14,7 @@ from typing import Optional, Tuple, List, Type, Dict
|
||||
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
||||
import habana_frameworks.torch as htorch
|
||||
from contextlib import nullcontext
|
||||
from optimum.habana.utils import HabanaProfile
|
||||
from optimum.habana.utils import HabanaProfile, to_gb_rounded
|
||||
|
||||
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
|
||||
from optimum.habana.checkpoint_utils import (
|
||||
@ -35,13 +37,31 @@ from loguru import logger
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
if 'GRAPH_VISUALIZATION' in os.environ:
|
||||
for f in glob.glob('.graph_dumps/*'):
|
||||
os.remove(f)
|
||||
|
||||
BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
|
||||
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
|
||||
TRACE_FILENAME = os.environ.get('TRACE_FILENAME')
|
||||
DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
|
||||
START_TS = None
|
||||
|
||||
def trace(txt):
|
||||
if TRACE_FILENAME is not None:
|
||||
print(txt, flush=True, file=open(TRACE_FILENAME, 'a'))
|
||||
|
||||
def count_hpu_graphs():
|
||||
return len(glob.glob('.graph_dumps/*PreGraph*'))
|
||||
|
||||
|
||||
def dbg_trace(tag, txt):
|
||||
global START_TS
|
||||
if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
|
||||
if START_TS is None:
|
||||
START_TS = time.perf_counter()
|
||||
time_offset = time.perf_counter() - START_TS
|
||||
mem_stats = htorch.hpu.memory.memory_stats()
|
||||
mem_used = to_gb_rounded(mem_stats['InUse'])
|
||||
max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
|
||||
print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
|
||||
f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
|
||||
|
||||
|
||||
def round_up(number, k):
|
||||
@ -52,12 +72,6 @@ def batch_alloc(new_bs, tensor):
|
||||
return tensor.new_empty((new_bs,) + tensor.shape[1:])
|
||||
|
||||
|
||||
def to_tensors(indices, device):
|
||||
def convert(idx):
|
||||
return torch.tensor(idx, device=device)
|
||||
return [[(convert(dst), convert(src)) for dst, src in batch_ind] for batch_ind in indices]
|
||||
|
||||
|
||||
def move_data(dst_tensor, chunk_size, indices, src_tensors):
|
||||
batch_dim = 0
|
||||
bs = dst_tensor.size(batch_dim)
|
||||
@ -172,7 +186,8 @@ class CausalLMBatch(Batch):
|
||||
# FIXME: max_seq_len for non optimized code
|
||||
max_input_length = max(req.input_length for req in requests)
|
||||
offsets = [(max_input_length - b.input_length) for b in batches]
|
||||
trace(f'RECOMBINE: bs:{new_bs} requests: {len(requests)} offsets: {offsets}')
|
||||
scenario = 'CONCAT' if len(batches) > 1 else 'FILTER'
|
||||
dbg_trace(scenario, f'bs:{[b.input_ids.size(0) for b in batches]}->{new_bs} num_reqs:{[len(b.requests) for b in batches]}->{len(requests)} offsets:{offsets}')
|
||||
|
||||
max_seq_len = batches[0].attention_mask.size(1)
|
||||
input_length = max(r.input_length for r in requests)
|
||||
@ -268,7 +283,7 @@ class CausalLMBatch(Batch):
|
||||
device: torch.device,
|
||||
is_optimized_for_gaudi: bool = False,
|
||||
) -> "CausalLMBatch":
|
||||
trace(f'NEW BATCH: ({len(pb.requests)}){[req.id for req in pb.requests]}')
|
||||
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
|
||||
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
|
||||
|
||||
max_input_length = max(r.data.truncate for r in requests)
|
||||
@ -355,13 +370,11 @@ class CausalLMBatch(Batch):
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
|
||||
trace("FILTER")
|
||||
return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi)
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
|
||||
trace('CONCAT')
|
||||
return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi)
|
||||
|
||||
def __len__(self):
|
||||
@ -536,7 +549,9 @@ class CausalLM(Model):
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
||||
trace(f'GENERATE ({len(batch.requests)}){[r.data.id for r in batch.requests]}, {batch.input_ids.shape}')
|
||||
prefill = batch.past_key_values is None
|
||||
scenario = 'PREFILL' if prefill else 'GENERATE'
|
||||
dbg_trace(scenario, f'bs:{batch.input_ids.size(0)} num_reqs:{len(batch.requests)} seq_len:{batch.input_ids.shape[1]}')
|
||||
self.step = self.step + 1
|
||||
if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
|
||||
self.hb_profer.stop()
|
||||
@ -550,7 +565,6 @@ class CausalLM(Model):
|
||||
# slice the attention mask to the correct shape
|
||||
# TODO fix me!
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
prefill = batch.past_key_values is None
|
||||
if batch.past_key_values:
|
||||
if token_idx is not None:
|
||||
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
|
||||
|
Loading…
Reference in New Issue
Block a user