Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
parent 941d36f3fd
commit d31fb62576
@@ -41,6 +41,7 @@ class Profiler():
     filename = "server_events.json"
 
     def __init__(self):
+        self.step = 0
         self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
         if self.enabled:
             # initialize the trace file
@@ -50,20 +51,17 @@ class Profiler():
             file_writer.start()
 
     @contextmanager
-    def record_event(self, type, name, args=None, util=None):
+    def record_event(self, type, name, args=None, util=None, count_step=False):
         if self.enabled:
             start = time.time() * 1000000.0
             if util is not None:
-                self.profiling_trace_events.put(json.dumps([{
-                    "pid": 1,
-                    "tid": self.event_tid["counter"],
-                    "ph": "C",
-                    "name": "util",
-                    "ts": start,
-                    "args": {
-                        "util": util["util"],
-                    }}]))
+                self._add_util_event(util, start)
+
+            if count_step:
+                if args is None:
+                    args = {}
+                args["step"] = self.step
+                self.step += 1
             event = {
                 "pid": 1,
                 "tid": self.event_tid[type],
@@ -80,4 +78,17 @@ class Profiler():
             self.profiling_trace_events.put(json.dumps([event]))
         else:
             yield
+
+    def _add_util_event(self, util, start):
+        util_event = {
+            "pid": 1,
+            "tid": self.event_tid["counter"],
+            "ph": "C",
+            "name": "util",
+            "ts": start,
+            "args": {
+                "util": util["util"],
+            }
+        }
+        self.profiling_trace_events.put(json.dumps([util_event]))
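Taken together, the Profiler hunks move the inline util-counter JSON into an `_add_util_event` helper and add a `count_step` flag that stamps an event's `args` with a monotonically increasing `step`. The sketch below is a minimal, self-contained rendition of that pattern for illustration only: the `MiniProfiler` name, its in-memory `events` list, the fixed `tid` values, and the closing "X" (complete) event with a `dur` field are assumptions standing in for the real class's `profiling_trace_events` queue and `FileWriter`, which this diff does not show.

import json
import time
from contextlib import contextmanager


class MiniProfiler:
    """Hypothetical, simplified stand-in for the Profiler class touched above."""

    def __init__(self, enabled=True):
        self.enabled = enabled
        self.step = 0        # incremented once per count_step=True event
        self.events = []     # stands in for profiling_trace_events / FileWriter

    def _add_util_event(self, util, start):
        # Counter ("C" phase) event, mirroring the helper extracted in the diff.
        self.events.append({
            "pid": 1,
            "tid": 1,
            "ph": "C",
            "name": "util",
            "ts": start,
            "args": {"util": util["util"]},
        })

    @contextmanager
    def record_event(self, type, name, args=None, util=None, count_step=False):
        if not self.enabled:
            yield
            return
        start = time.time() * 1000000.0   # trace timestamps in microseconds
        if util is not None:
            self._add_util_event(util, start)
        if count_step:
            args = dict(args or {})
            args["step"] = self.step
            self.step += 1
        yield
        # Complete ("X" phase) event with a duration; the real class builds its
        # event outside the hunks shown, so treat this part as an illustrative guess.
        self.events.append({
            "pid": 1,
            "tid": 2 if type == "external" else 3,
            "ph": "X",
            "name": name,
            "ts": start,
            "dur": time.time() * 1000000.0 - start,
            "args": args or {},
        })


if __name__ == "__main__":
    profiler = MiniProfiler()
    with profiler.record_event(type="internal", name="generate_token", count_step=True):
        time.sleep(0.01)
    print(json.dumps(profiler.events, indent=2))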
@@ -53,10 +53,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
-        with self.profiler.record_event("external",
-                                        "filter_batch",
-                                        {"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
-                                        {"util": len(batch.requests)}):
+        with self.profiler.record_event(
+            type="external",
+            name="filter_batch",
+            args={"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
+            util={"util": len(batch.requests)}
+        ):
             if batch is None:
                 raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
             filtered_batch = batch.filter(request.request_ids)
@@ -81,9 +83,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
         )
-        with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):
-            with self.profiler.record_event("internal", "generate_token"):
+        with self.profiler.record_event(
+            type="external",
+            name="prefill",
+            args={"batch_size": batch.batch_size, "sequence_length": batch.seq_length}
+        ):
+            with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
                 generations, next_batch = self.model.generate_token([batch])
                 self.cache.set(next_batch)
@@ -94,10 +99,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def Decode(self, request, context):
         batch0 = self.cache.cache[request.batches[0].id]
-        with self.profiler.record_event("external",
-                                        "decode",
-                                        {"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.input_ids.size(0)},
-                                        {"util": len(batch0.requests)}):
+        with self.profiler.record_event(
+            type="external",
+            name="decode",
+            args={"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.batch_size},
+            util={"util": len(batch0.requests)}
+        ):
             if len(request.batches) == 0:
                 raise ValueError("Must provide at least one batch")
@@ -111,7 +118,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")
 
-        with self.profiler.record_event("internal", "generate_token"):
+        with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
            generations, next_batch = self.model.generate_token(batches)
            self.cache.set(next_batch)
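For context beyond the diff itself: the first Profiler hunk shows the profiler is active only when TGI_PROFILER_ENABLED=true and RANK is 0, and the events it writes to server_events.json carry the pid/tid/ph/ts/args fields of the Chrome trace event format. A small, hypothetical post-processing snippet, assuming profiling has finished and the file is valid JSON holding either a plain list of events or a {"traceEvents": [...]} wrapper:

import json

# Hypothetical helper: count the stepped generate_token events in the trace file
# named in the diff. The file name comes from the Profiler hunk above; the rest
# of this snippet is an assumption about how the finished trace can be read back.
with open("server_events.json") as f:
    trace = json.load(f)

events = trace["traceEvents"] if isinstance(trace, dict) else trace
steps = [
    e["args"]["step"]
    for e in events
    if e.get("name") == "generate_token" and "step" in (e.get("args") or {})
]
print(f"{len(steps)} generate_token events carry a step; last step = {steps[-1] if steps else None}")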