Add more info to high-level profiler events (#46) (#79)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Authored by Karol Damaszke on 2024-02-28 09:55:50 +01:00; committed via GitHub.
parent 941d36f3fd
commit d31fb62576
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 23 deletions

View File

@ -41,6 +41,7 @@ class Profiler():
filename = "server_events.json"
def __init__(self):
self.step = 0
self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
if self.enabled:
# initialize the trace file
@ -50,20 +51,17 @@ class Profiler():
file_writer.start()
@contextmanager
def record_event(self, type, name, args=None, util=None):
def record_event(self, type, name, args=None, util=None, count_step=False):
if self.enabled:
start = time.time() * 1000000.0
if util is not None:
self.profiling_trace_events.put(json.dumps([{
"pid": 1,
"tid": self.event_tid["counter"],
"ph": "C",
"name": "util",
"ts": start,
"args": {
"util": util["util"],
}}]))
self._add_util_event(util, start)
if count_step:
if args is None:
args = {}
args["step"] = self.step
self.step += 1
event = {
"pid": 1,
"tid": self.event_tid[type],
@ -80,4 +78,17 @@ class Profiler():
self.profiling_trace_events.put(json.dumps([event]))
else:
yield
yield
def _add_util_event(self, util, start):
util_event = {
"pid": 1,
"tid": self.event_tid["counter"],
"ph": "C",
"name": "util",
"ts": start,
"args": {
"util": util["util"],
}
}
self.profiling_trace_events.put(json.dumps([util_event]))

View File

@ -53,10 +53,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def FilterBatch(self, request, context):
batch = self.cache.pop(request.batch_id)
with self.profiler.record_event("external",
"filter_batch",
{"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
{"util": len(batch.requests)}):
with self.profiler.record_event(
type="external",
name="filter_batch",
args={"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
util={"util": len(batch.requests)}
):
if batch is None:
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
filtered_batch = batch.filter(request.request_ids)
@ -81,9 +83,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
)
with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):
with self.profiler.record_event("internal", "generate_token"):
with self.profiler.record_event(
type="external",
name="prefill",
args={"batch_size": batch.batch_size, "sequence_length": batch.seq_length}
):
with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
generations, next_batch = self.model.generate_token([batch])
self.cache.set(next_batch)
@ -94,10 +99,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Decode(self, request, context):
batch0 = self.cache.cache[request.batches[0].id]
with self.profiler.record_event("external",
"decode",
{"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.input_ids.size(0)},
{"util": len(batch0.requests)}):
with self.profiler.record_event(
type="external",
name="decode",
args={"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.batch_size},
util={"util": len(batch0.requests)}
):
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")
@ -111,7 +118,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) == 0:
raise ValueError("All batches are empty")
with self.profiler.record_event("internal", "generate_token"):
with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
generations, next_batch = self.model.generate_token(batches)
self.cache.set(next_batch)