Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
parent 941d36f3fd
commit d31fb62576
@@ -41,6 +41,7 @@ class Profiler():
     filename = "server_events.json"
 
     def __init__(self):
+        self.step = 0
         self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
         if self.enabled:
             # initialize the trace file
@@ -50,20 +51,17 @@ class Profiler():
             file_writer.start()
 
     @contextmanager
-    def record_event(self, type, name, args=None, util=None):
+    def record_event(self, type, name, args=None, util=None, count_step=False):
         if self.enabled:
             start = time.time() * 1000000.0
             if util is not None:
-                self.profiling_trace_events.put(json.dumps([{
-                    "pid": 1,
-                    "tid": self.event_tid["counter"],
-                    "ph": "C",
-                    "name": "util",
-                    "ts": start,
-                    "args": {
-                        "util": util["util"],
-                    }}]))
+                self._add_util_event(util, start)
 
+            if count_step:
+                if args is None:
+                    args = {}
+                args["step"] = self.step
+                self.step += 1
             event = {
                 "pid": 1,
                 "tid": self.event_tid[type],
@@ -81,3 +79,16 @@ class Profiler():
             self.profiling_trace_events.put(json.dumps([event]))
         else:
             yield
+
+    def _add_util_event(self, util, start):
+        util_event = {
+            "pid": 1,
+            "tid": self.event_tid["counter"],
+            "ph": "C",
+            "name": "util",
+            "ts": start,
+            "args": {
+                "util": util["util"],
+            }
+        }
+        self.profiling_trace_events.put(json.dumps([util_event]))
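For readers skimming the diff: below is a small, self-contained sketch of the trace-event shapes involved. It is illustrative only; the queue, the event_tid mapping, and the step handling are simplified stand-ins for Profiler internals that this commit shows only in part.

# Illustrative sketch (not part of the commit): the JSON event shapes the patched
# Profiler enqueues. Queue, tid mapping, and step counter are stand-ins for the real
# class attributes, which this diff shows only partially.
import json
import time
from queue import Queue

profiling_trace_events = Queue()           # stand-in for self.profiling_trace_events
event_tid = {"counter": 0, "internal": 1}  # assumed tid mapping
step = 0                                   # mirrors the new self.step counter


def add_util_event(util, ts):
    # Counter event ("ph": "C"), matching the new _add_util_event helper.
    profiling_trace_events.put(json.dumps([{
        "pid": 1,
        "tid": event_tid["counter"],
        "ph": "C",
        "name": "util",
        "ts": ts,
        "args": {"util": util["util"]},
    }]))


def record_step_args(args=None):
    # The count_step=True path: stamp the event args with a running step index.
    global step
    if args is None:
        args = {}
    args["step"] = step
    step += 1
    return args


add_util_event({"util": 8}, time.time() * 1000000.0)
print(record_step_args({"batch_size": 8}))  # {'batch_size': 8, 'step': 0}
print(profiling_trace_events.get())         # the util counter event as JSON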
@@ -53,10 +53,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
-        with self.profiler.record_event("external",
-                                        "filter_batch",
-                                        {"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
-                                        {"util": len(batch.requests)}):
+        with self.profiler.record_event(
+            type="external",
+            name="filter_batch",
+            args={"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
+            util={"util": len(batch.requests)}
+        ):
             if batch is None:
                 raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
             filtered_batch = batch.filter(request.request_ids)
@@ -81,9 +83,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
         )
-        with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):
-            with self.profiler.record_event("internal", "generate_token"):
+        with self.profiler.record_event(
+            type="external",
+            name="prefill",
+            args={"batch_size": batch.batch_size, "sequence_length": batch.seq_length}
+        ):
+            with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
                 generations, next_batch = self.model.generate_token([batch])
         self.cache.set(next_batch)
 
@@ -94,10 +99,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def Decode(self, request, context):
         batch0 = self.cache.cache[request.batches[0].id]
-        with self.profiler.record_event("external",
-                                        "decode",
-                                        {"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.input_ids.size(0)},
-                                        {"util": len(batch0.requests)}):
+        with self.profiler.record_event(
+            type="external",
+            name="decode",
+            args={"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.batch_size},
+            util={"util": len(batch0.requests)}
+        ):
             if len(request.batches) == 0:
                 raise ValueError("Must provide at least one batch")
 
@@ -111,7 +118,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")
 
-        with self.profiler.record_event("internal", "generate_token"):
+        with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
             generations, next_batch = self.model.generate_token(batches)
         self.cache.set(next_batch)
 
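A hedged usage sketch of the updated record_event API follows. Only the environment variable names, the keyword-argument call style, and the output filename come from the diff above; the import path and the wrapped work are assumptions.

# Usage sketch under assumptions: the module path below is a guess, and the body of
# the with-block stands in for the server's real generate_token call.
import os

os.environ["TGI_PROFILER_ENABLED"] = "true"  # profiling is active only when this is "true" ...
os.environ["RANK"] = "0"                     # ... and only on rank 0

from text_generation_server.profiler import Profiler  # assumed import path

profiler = Profiler()

# Keyword-argument style introduced by this commit; count_step=True adds a running
# "step" index to the event's args.
with profiler.record_event(type="internal", name="generate_token", count_step=True):
    pass  # the server wraps self.model.generate_token(...) here

# Events accumulate as Chrome-trace JSON in "server_events.json", viewable in
# chrome://tracing or Perfetto.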