diff --git a/server/text_generation_server/profiler.py b/server/text_generation_server/profiler.py
index 561445e7..e0248932 100644
--- a/server/text_generation_server/profiler.py
+++ b/server/text_generation_server/profiler.py
@@ -41,6 +41,7 @@ class Profiler():
     filename = "server_events.json"

     def __init__(self):
+        self.step = 0
         self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
         if self.enabled:
             # initialize the trace file
@@ -50,20 +51,17 @@ class Profiler():
             file_writer.start()

     @contextmanager
-    def record_event(self, type, name, args=None, util=None):
+    def record_event(self, type, name, args=None, util=None, count_step=False):
         if self.enabled:
             start = time.time() * 1000000.0
             if util is not None:
-                self.profiling_trace_events.put(json.dumps([{
-                    "pid": 1,
-                    "tid": self.event_tid["counter"],
-                    "ph": "C",
-                    "name": "util",
-                    "ts": start,
-                    "args": {
-                        "util": util["util"],
-                    }}]))
+                self._add_util_event(util, start)
+            if count_step:
+                if args is None:
+                    args = {}
+                args["step"] = self.step
+                self.step += 1
             event = {
                 "pid": 1,
                 "tid": self.event_tid[type],
@@ -80,4 +78,17 @@ class Profiler():
             self.profiling_trace_events.put(json.dumps([event]))
         else:
-            yield
\ No newline at end of file
+            yield
+
+    def _add_util_event(self, util, start):
+        util_event = {
+            "pid": 1,
+            "tid": self.event_tid["counter"],
+            "ph": "C",
+            "name": "util",
+            "ts": start,
+            "args": {
+                "util": util["util"],
+            }
+        }
+        self.profiling_trace_events.put(json.dumps([util_event]))
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index dd64a93d..6f3e49f2 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -53,10 +53,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
-        with self.profiler.record_event("external",
-                                        "filter_batch",
-                                        {"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
-                                        {"util": len(batch.requests)}):
+        with self.profiler.record_event(
+            type="external",
+            name="filter_batch",
+            args={"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
+            util={"util": len(batch.requests)}
+        ):
             if batch is None:
                 raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
             filtered_batch = batch.filter(request.request_ids)
@@ -81,9 +83,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
         )
-        with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):
-
-            with self.profiler.record_event("internal", "generate_token"):
+        with self.profiler.record_event(
+            type="external",
+            name="prefill",
+            args={"batch_size": batch.batch_size, "sequence_length": batch.seq_length}
+        ):
+            with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
                 generations, next_batch = self.model.generate_token([batch])
                 self.cache.set(next_batch)

@@ -94,10 +99,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

     async def Decode(self, request, context):
         batch0 = self.cache.cache[request.batches[0].id]
-        with self.profiler.record_event("external",
-                                        "decode",
-                                        {"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.input_ids.size(0)},
-                                        {"util": len(batch0.requests)}):
+        with self.profiler.record_event(
+            type="external",
+            name="decode",
+            args={"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.batch_size},
+            util={"util": len(batch0.requests)}
+        ):
             if len(request.batches) == 0:
                 raise ValueError("Must provide at least one batch")

@@ -111,7 +118,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")

-        with self.profiler.record_event("internal", "generate_token"):
+        with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
            generations, next_batch = self.model.generate_token(batches)
            self.cache.set(next_batch)
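Note: the events emitted above follow the Chrome trace-event JSON format, where "ph": "C" marks a counter sample (rendered as a utilization graph) and timestamps are in microseconds. Below is a minimal, self-contained sketch of the same pattern, loadable into chrome://tracing or Perfetto. Everything here is a stand-in for illustration: demo_record_event and the list-backed trace_events replace the real class, which writes through a queue drained by a background file writer, and the "ph": "X" complete event is an assumption about the tail of record_event that the diff truncates.

import json
import time
from contextlib import contextmanager

trace_events = []  # stand-in for Profiler.profiling_trace_events (a queue)
step = 0           # mirrors the new Profiler.step counter

@contextmanager
def demo_record_event(type, name, args=None, util=None, count_step=False):
    global step
    start = time.time() * 1000000.0  # trace timestamps are in microseconds
    if util is not None:
        # counter event ("ph": "C"): plots the util value over time
        trace_events.append({"pid": 1, "tid": 2, "ph": "C", "name": "util",
                             "ts": start, "args": {"util": util["util"]}})
    if count_step:
        args = dict(args or {})
        args["step"] = step  # tag each generate_token call with a step index
        step += 1
    yield
    # complete event ("ph": "X") with an explicit duration (assumed shape)
    trace_events.append({"pid": 1, "tid": 1, "ph": "X", "name": name,
                         "ts": start, "dur": time.time() * 1000000.0 - start,
                         "args": args or {}})

with demo_record_event(type="internal", name="generate_token", count_step=True):
    time.sleep(0.01)  # stand-in for model.generate_token

print(json.dumps(trace_events, indent=2))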