Debugging utils (#14)

madamczykhabana authored on 2024-01-15 21:05:27 +01:00 (committed by GitHub)
parent a8c5b69e2c
commit 41c4f4fa41
2 changed files with 35 additions and 18 deletions


@@ -6,6 +6,8 @@ from grpc_status import rpc_status
 from grpc_interceptor.server import AsyncServerInterceptor
 from loguru import logger
 from typing import Callable, Any
+import traceback
+import os


 class ExceptionInterceptor(AsyncServerInterceptor):
@@ -20,6 +22,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             response = method(request_or_iterator, context)
             return await response
         except Exception as err:
+            trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")
@@ -28,6 +31,6 @@
             await context.abort_with_status(
                 rpc_status.to_status(
-                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
+                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                 )
             )
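
The net effect of the three hunks above: when the DUMP_STACK environment variable is set, the gRPC INTERNAL status returned to the client carries the full Python traceback in addition to str(err). A minimal sketch of the same pattern in isolation (format_error is a hypothetical helper, not part of the commit):

import os
import traceback

def format_error(err: Exception) -> str:
    # Append the formatted stack trace only when DUMP_STACK is set,
    # mirroring the interceptor change above.
    trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
    return str(err) + trace

try:
    raise ValueError("boom")
except Exception as err:
    # With DUMP_STACK=1 this prints the message plus the traceback;
    # unset, clients see only str(err).
    print(format_error(err))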


@@ -1,6 +1,8 @@
 import os
 import tempfile
 import itertools
+import time
+import glob

 from text_generation_server.utils.tokens import batch_top_tokens
 import torch
@@ -12,7 +14,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 import habana_frameworks.torch as htorch
 from contextlib import nullcontext
-from optimum.habana.utils import HabanaProfile
+from optimum.habana.utils import HabanaProfile, to_gb_rounded
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -35,13 +37,31 @@ from loguru import logger
 tracer = trace.get_tracer(__name__)

+if 'GRAPH_VISUALIZATION' in os.environ:
+    for f in glob.glob('.graph_dumps/*'):
+        os.remove(f)
+
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
-TRACE_FILENAME = os.environ.get('TRACE_FILENAME')
+DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
+START_TS = None


-def trace(txt):
-    if TRACE_FILENAME is not None:
-        print(txt, flush=True, file=open(TRACE_FILENAME, 'a'))
+def count_hpu_graphs():
+    return len(glob.glob('.graph_dumps/*PreGraph*'))
+
+
+def dbg_trace(tag, txt):
+    global START_TS
+    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
+        if START_TS is None:
+            START_TS = time.perf_counter()
+        time_offset = time.perf_counter() - START_TS
+        mem_stats = htorch.hpu.memory.memory_stats()
+        mem_used = to_gb_rounded(mem_stats['InUse'])
+        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
+        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
+              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))


 def round_up(number, k):
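
On rank 0, each dbg_trace call appends a single line such as
ts:12.345s g:3 mu:31.2GB mmu:45.0GB | PREFILL | bs:4 num_reqs:4 seq_len:128
to DBG_TRACE_FILENAME. A small sketch for post-processing such a log (parse_dbg_trace is a hypothetical helper, assuming the exact line format emitted above):

def parse_dbg_trace(path):
    # Yields (seconds_since_start, tag, payload) per dbg_trace line.
    entries = []
    with open(path) as f:
        for line in f:
            stats, tag, txt = (part.strip() for part in line.split('|', 2))
            ts = float(stats.split()[0].removeprefix('ts:').removesuffix('s'))
            entries.append((ts, tag, txt))
    return entries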
@@ -52,12 +72,6 @@ def batch_alloc(new_bs, tensor):
     return tensor.new_empty((new_bs,) + tensor.shape[1:])

-def to_tensors(indices, device):
-    def convert(idx):
-        return torch.tensor(idx, device=device)
-    return [[(convert(dst), convert(src)) for dst, src in batch_ind] for batch_ind in indices]
-
-
 def move_data(dst_tensor, chunk_size, indices, src_tensors):
     batch_dim = 0
     bs = dst_tensor.size(batch_dim)
@@ -172,7 +186,8 @@ class CausalLMBatch(Batch):
         # FIXME: max_seq_len for non optimized code
         max_input_length = max(req.input_length for req in requests)
         offsets = [(max_input_length - b.input_length) for b in batches]
-        trace(f'RECOMBINE: bs:{new_bs} requests: {len(requests)} offsets: {offsets}')
+        scenario = 'CONCAT' if len(batches) > 1 else 'FILTER'
+        dbg_trace(scenario, f'bs:{[b.input_ids.size(0) for b in batches]}->{new_bs} num_reqs:{[len(b.requests) for b in batches]}->{len(requests)} offsets:{offsets}')

         max_seq_len = batches[0].attention_mask.size(1)
         input_length = max(r.input_length for r in requests)
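
The new_bs reported in the arrow notation is the bucketed target batch size; BATCH_BUCKET_SIZE and PREFILL_BATCH_BUCKET_SIZE above control the rounding granularity so that batch shapes stay within a small set of HPU graph variants. The diff only shows round_up's signature; a plausible body, stated as an assumption:

def round_up(number, k):
    # Round number up to the nearest multiple of k.
    return (number + k - 1) // k * k

# With BATCH_BUCKET_SIZE = 8, recombining batches of sizes [4, 2]
# would pad to bs 8, which dbg_trace reports as bs:[4, 2]->8.
assert round_up(6, 8) == 8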
@@ -268,7 +283,7 @@
         device: torch.device,
         is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
-        trace(f'NEW BATCH: ({len(pb.requests)}){[req.id for req in pb.requests]}')
+        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
         max_input_length = max(r.data.truncate for r in requests)
@@ -355,13 +370,11 @@
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
-        trace("FILTER")
         return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi)

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        trace('CONCAT')
         return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi)

     def __len__(self):
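
Both filter() and concatenate() delegate to recombine(), so their individual trace calls are dropped and the single dbg_trace inside recombine distinguishes the two paths by the number of incoming batches. The tag selection in isolation:

def scenario_for(batches):
    # recombine() sees one batch when called from filter() and
    # several when called from concatenate(), hence the tag above.
    return 'CONCAT' if len(batches) > 1 else 'FILTER'

assert scenario_for(['b0']) == 'FILTER'
assert scenario_for(['b0', 'b1']) == 'CONCAT'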
@@ -536,7 +549,9 @@
     @tracer.start_as_current_span("generate_token")
     def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        trace(f'GENERATE ({len(batch.requests)}){[r.data.id for r in batch.requests]}, {batch.input_ids.shape}')
+        prefill = batch.past_key_values is None
+        scenario = 'PREFILL' if prefill else 'GENERATE'
+        dbg_trace(scenario, f'bs:{batch.input_ids.size(0)} num_reqs:{len(batch.requests)} seq_len:{batch.input_ids.shape[1]}')
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
             self.hb_profer.stop()
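
generate_token now derives prefill once at the top (the duplicate assignment further down is removed in the next hunk) and tags the step from it: the first forward pass for a batch has no KV cache yet. A minimal illustration:

def step_kind(past_key_values):
    # No cached keys/values yet -> first pass over the prompt (PREFILL);
    # otherwise the cache is extended one token per step (GENERATE).
    return 'PREFILL' if past_key_values is None else 'GENERATE'

assert step_kind(None) == 'PREFILL'
assert step_kind([('k', 'v')]) == 'GENERATE'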
@@ -550,7 +565,6 @@
         # slice the attention mask to the correct shape
         # TODO fix me!
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        prefill = batch.past_key_values is None
         if batch.past_key_values:
             if token_idx is not None:
                 input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
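
Taken together, the commit adds three opt-in debug switches. A sketch of enabling them for a local run (values and paths are examples; GRAPH_VISUALIZATION itself is read by the Habana stack to produce the .graph_dumps that count_hpu_graphs() inspects, while the new startup code only clears stale dumps):

import os

os.environ['DUMP_STACK'] = '1'                           # tracebacks in gRPC error statuses
os.environ['GRAPH_VISUALIZATION'] = '1'                  # graph dumps; stale .graph_dumps/* removed at startup
os.environ['DBG_TRACE_FILENAME'] = '/tmp/dbg_trace.log'  # timestamped step/memory log on rank 0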