Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-23 16:02:10 +00:00
Debugging utils (#14)

commit 41c4f4fa41
parent a8c5b69e2c
@@ -6,6 +6,8 @@ from grpc_status import rpc_status
 from grpc_interceptor.server import AsyncServerInterceptor
 from loguru import logger
 from typing import Callable, Any
+import traceback
+import os
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
@@ -20,6 +22,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             response = method(request_or_iterator, context)
             return await response
         except Exception as err:
+            trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")
 
@@ -28,6 +31,6 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             await context.abort_with_status(
                 rpc_status.to_status(
-                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
+                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                 )
             )
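The hunks above extend the gRPC ExceptionInterceptor: when the DUMP_STACK environment variable is set, the formatted traceback is appended to the INTERNAL status message returned to the client. Below is a minimal standalone sketch of the same pattern outside gRPC; the helper name run_with_optional_stack and the failing lambda are illustrative only, not part of the commit.

    import os
    import traceback

    def run_with_optional_stack(fn):
        # Same pattern as the patched interceptor: include the formatted stack
        # trace in the reported message only when DUMP_STACK is set.
        try:
            return fn()
        except Exception as err:
            trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
            return f"INTERNAL: {err}{trace}"

    # Hypothetical failing callable, used only to show the two output shapes.
    print(run_with_optional_stack(lambda: 1 / 0))   # message only
    os.environ['DUMP_STACK'] = '1'
    print(run_with_optional_stack(lambda: 1 / 0))   # message plus full traceback

The remaining hunks below belong to the Habana CausalLM model code.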
@@ -1,6 +1,8 @@
 import os
 import tempfile
 import itertools
+import time
+import glob
 
 from text_generation_server.utils.tokens import batch_top_tokens
 import torch
@@ -12,7 +14,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 import habana_frameworks.torch as htorch
 from contextlib import nullcontext
-from optimum.habana.utils import HabanaProfile
+from optimum.habana.utils import HabanaProfile, to_gb_rounded
 
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -35,13 +37,31 @@ from loguru import logger
 
 tracer = trace.get_tracer(__name__)
 
+if 'GRAPH_VISUALIZATION' in os.environ:
+    for f in glob.glob('.graph_dumps/*'):
+        os.remove(f)
+
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
-TRACE_FILENAME = os.environ.get('TRACE_FILENAME')
+DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
+START_TS = None
+
 
-def trace(txt):
-    if TRACE_FILENAME is not None:
-        print(txt, flush=True, file=open(TRACE_FILENAME, 'a'))
+def count_hpu_graphs():
+    return len(glob.glob('.graph_dumps/*PreGraph*'))
 
 
+def dbg_trace(tag, txt):
+    global START_TS
+    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
+        if START_TS is None:
+            START_TS = time.perf_counter()
+        time_offset = time.perf_counter() - START_TS
+        mem_stats = htorch.hpu.memory.memory_stats()
+        mem_used = to_gb_rounded(mem_stats['InUse'])
+        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
+        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
+              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
+
+
 def round_up(number, k):
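The hunk above replaces the old trace() helper with dbg_trace(): it lazily captures a start timestamp, logs only on rank 0, and appends one line per event (time offset, HPU graph count, current and peak device memory) to the file named by DBG_TRACE_FILENAME. A stripped-down sketch of the same pattern that runs without the habana_frameworks stack follows; the HPU memory counters obtained in the real helper via htorch.hpu.memory.memory_stats() and to_gb_rounded are omitted, and the final call only illustrates how generate_token invokes it.

    import glob
    import os
    import time

    DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
    START_TS = None

    def count_hpu_graphs():
        # Same heuristic as the patch: count dumped HPU graph files, if any exist.
        return len(glob.glob('.graph_dumps/*PreGraph*'))

    def dbg_trace(tag, txt):
        # Stripped-down variant of the patched helper (no HPU memory stats).
        global START_TS
        if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
            if START_TS is None:
                START_TS = time.perf_counter()
            time_offset = time.perf_counter() - START_TS
            with open(DBG_TRACE_FILENAME, 'a') as f:
                print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} | {tag} | {txt}',
                      flush=True, file=f)

    # Illustrative call, mirroring how generate_token uses it in the patch.
    dbg_trace('PREFILL', 'bs:4 num_reqs:4 seq_len:128')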
@@ -52,12 +72,6 @@ def batch_alloc(new_bs, tensor):
     return tensor.new_empty((new_bs,) + tensor.shape[1:])
 
 
-def to_tensors(indices, device):
-    def convert(idx):
-        return torch.tensor(idx, device=device)
-    return [[(convert(dst), convert(src)) for dst, src in batch_ind] for batch_ind in indices]
-
-
 def move_data(dst_tensor, chunk_size, indices, src_tensors):
     batch_dim = 0
     bs = dst_tensor.size(batch_dim)
@@ -172,7 +186,8 @@ class CausalLMBatch(Batch):
         # FIXME: max_seq_len for non optimized code
         max_input_length = max(req.input_length for req in requests)
        offsets = [(max_input_length - b.input_length) for b in batches]
-        trace(f'RECOMBINE: bs:{new_bs} requests: {len(requests)} offsets: {offsets}')
+        scenario = 'CONCAT' if len(batches) > 1 else 'FILTER'
+        dbg_trace(scenario, f'bs:{[b.input_ids.size(0) for b in batches]}->{new_bs} num_reqs:{[len(b.requests) for b in batches]}->{len(requests)} offsets:{offsets}')
 
         max_seq_len = batches[0].attention_mask.size(1)
         input_length = max(r.input_length for r in requests)
@@ -268,7 +283,7 @@ class CausalLMBatch(Batch):
         device: torch.device,
         is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
-        trace(f'NEW BATCH: ({len(pb.requests)}){[req.id for req in pb.requests]}')
+        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
 
         max_input_length = max(r.data.truncate for r in requests)
@@ -355,13 +370,11 @@ class CausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
-        trace("FILTER")
         return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi)
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        trace('CONCAT')
         return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi)
 
     def __len__(self):
@@ -536,7 +549,9 @@ class CausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
    def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        trace(f'GENERATE ({len(batch.requests)}){[r.data.id for r in batch.requests]}, {batch.input_ids.shape}')
+        prefill = batch.past_key_values is None
+        scenario = 'PREFILL' if prefill else 'GENERATE'
+        dbg_trace(scenario, f'bs:{batch.input_ids.size(0)} num_reqs:{len(batch.requests)} seq_len:{batch.input_ids.shape[1]}')
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
             self.hb_profer.stop()
@@ -550,7 +565,6 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         # TODO fix me!
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        prefill = batch.past_key_values is None
         if batch.past_key_values:
             if token_idx is not None:
                 input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
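Taken together, the commit introduces three opt-in debugging switches, all read from the environment by the patched modules, so they must be set before the server process imports them. A configuration sketch (the values and the log file name are illustrative only):

    import os

    os.environ['DUMP_STACK'] = '1'                      # interceptor: append the traceback to gRPC INTERNAL errors
    os.environ['DBG_TRACE_FILENAME'] = 'dbg_trace.log'  # causal_lm: enable dbg_trace() logging (rank 0 only)
    os.environ['GRAPH_VISUALIZATION'] = '1'             # causal_lm: clear .graph_dumps/* at import time
    # Each dbg_trace() line has the shape:
    #   ts:<seconds>s g:<hpu graph count> mu:<GB> mmu:<GB> | <tag> | <details>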