diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py
index 725105f3..515c2c9d 100644
--- a/server/text_generation_server/interceptor.py
+++ b/server/text_generation_server/interceptor.py
@@ -6,6 +6,8 @@ from grpc_status import rpc_status
 from grpc_interceptor.server import AsyncServerInterceptor
 from loguru import logger
 from typing import Callable, Any
+import traceback
+import os
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
@@ -20,6 +22,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             response = method(request_or_iterator, context)
             return await response
         except Exception as err:
+            trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")
 
@@ -28,6 +31,6 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             await context.abort_with_status(
                 rpc_status.to_status(
-                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
+                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                 )
             )
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 3cfa7d6c..533d35f0 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1,6 +1,8 @@
 import os
 import tempfile
 import itertools
+import time
+import glob
 
 from text_generation_server.utils.tokens import batch_top_tokens
 import torch
@@ -12,7 +14,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 import habana_frameworks.torch as htorch
 from contextlib import nullcontext
-from optimum.habana.utils import HabanaProfile
+from optimum.habana.utils import HabanaProfile, to_gb_rounded
 
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -35,13 +37,31 @@ from loguru import logger
 
 tracer = trace.get_tracer(__name__)
 
+if 'GRAPH_VISUALIZATION' in os.environ:
+    for f in glob.glob('.graph_dumps/*'):
+        os.remove(f)
+
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
-TRACE_FILENAME = os.environ.get('TRACE_FILENAME')
+DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
+START_TS = None
 
 
-def trace(txt):
-    if TRACE_FILENAME is not None:
-        print(txt, flush=True, file=open(TRACE_FILENAME, 'a'))
+def count_hpu_graphs():
+    return len(glob.glob('.graph_dumps/*PreGraph*'))
+
+
+def dbg_trace(tag, txt):
+    global START_TS
+    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
+        if START_TS is None:
+            START_TS = time.perf_counter()
+        time_offset = time.perf_counter() - START_TS
+        mem_stats = htorch.hpu.memory.memory_stats()
+        mem_used = to_gb_rounded(mem_stats['InUse'])
+        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
+        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
+              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
 
 
 def round_up(number, k):
@@ -52,12 +72,6 @@ def batch_alloc(new_bs, tensor):
     return tensor.new_empty((new_bs,) + tensor.shape[1:])
 
 
-def to_tensors(indices, device):
-    def convert(idx):
-        return torch.tensor(idx, device=device)
-    return [[(convert(dst), convert(src)) for dst, src in batch_ind] for batch_ind in indices]
-
-
 def move_data(dst_tensor, chunk_size, indices, src_tensors):
     batch_dim = 0
     bs = dst_tensor.size(batch_dim)
@@ -172,7 +186,8 @@ class CausalLMBatch(Batch):
         # FIXME: max_seq_len for non optimized code
         max_input_length = max(req.input_length for req in requests)
         offsets = [(max_input_length - b.input_length) for b in batches]
-        trace(f'RECOMBINE: bs:{new_bs} requests: {len(requests)} offsets: {offsets}')
+        scenario = 'CONCAT' if len(batches) > 1 else 'FILTER'
+        dbg_trace(scenario, f'bs:{[b.input_ids.size(0) for b in batches]}->{new_bs} num_reqs:{[len(b.requests) for b in batches]}->{len(requests)} offsets:{offsets}')
 
         max_seq_len = batches[0].attention_mask.size(1)
         input_length = max(r.input_length for r in requests)
@@ -268,7 +283,7 @@ class CausalLMBatch(Batch):
         device: torch.device,
         is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
-        trace(f'NEW BATCH: ({len(pb.requests)}){[req.id for req in pb.requests]}')
+        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
 
         max_input_length = max(r.data.truncate for r in requests)
@@ -355,13 +370,11 @@ class CausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
-        trace("FILTER")
         return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi)
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        trace('CONCAT')
         return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi)
 
     def __len__(self):
@@ -536,7 +549,9 @@ class CausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        trace(f'GENERATE ({len(batch.requests)}){[r.data.id for r in batch.requests]}, {batch.input_ids.shape}')
+        prefill = batch.past_key_values is None
+        scenario = 'PREFILL' if prefill else 'GENERATE'
+        dbg_trace(scenario, f'bs:{batch.input_ids.size(0)} num_reqs:{len(batch.requests)} seq_len:{batch.input_ids.shape[1]}')
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
             self.hb_profer.stop()
@@ -550,7 +565,6 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         # TODO fix me!
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        prefill = batch.past_key_values is None
         if batch.past_key_values:
             if token_idx is not None:
                 input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
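Note on the trace format introduced above: when `DBG_TRACE_FILENAME` is set, `dbg_trace` appends lines of the form `ts:<seconds>s g:<hpu graph count> mu:<GB>GB mmu:<GB>GB | <TAG> | <text>`. The sketch below is a minimal, hypothetical helper (not part of this change) for parsing such a log offline; the log file name is an assumption for illustration.

```python
# Hypothetical parser for dbg_trace output; assumes the f-string format used in this diff.
import re

LINE_RE = re.compile(
    r'ts:(?P<ts>[\d.]+)s g:(?P<graphs>\d+) mu:(?P<mu>[\d.]+)GB '
    r'mmu:(?P<mmu>[\d.]+)GB \| (?P<tag>\S+) \| (?P<txt>.*)'
)

def parse_dbg_trace(path):
    """Yield one dict per trace line read from the file DBG_TRACE_FILENAME pointed to."""
    with open(path) as f:
        for line in f:
            m = LINE_RE.match(line.strip())
            if m:
                yield m.groupdict()

if __name__ == "__main__":
    # "dbg_trace.log" is an assumed example value of DBG_TRACE_FILENAME.
    for event in parse_dbg_trace("dbg_trace.log"):
        print(event["tag"], event["txt"])
```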