Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
High-level server profiler (#13)
This commit is contained in:
parent 41c4f4fa41
commit c459c86f88
@@ -79,6 +79,7 @@ Environment Variables Added:

| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths (e.g. 300/212) | add -e in docker run command |
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
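The new `TGI_PROFILER_ENABLED` switch is consumed by the profiler module added below. As a rough standalone sketch (not part of this commit: the event names and sleeps are placeholders, and it assumes the `text_generation_server` package is importable), enabling it from Python looks like this; inside the container the same effect comes from adding `-e TGI_PROFILER_ENABLED=true` to the `docker run` command:

```python
import os
import time

os.environ["TGI_PROFILER_ENABLED"] = "true"  # equivalent of `-e TGI_PROFILER_ENABLED=true` in docker run
os.environ.setdefault("RANK", "0")           # the profiler only records on rank 0

from text_generation_server.profiler import Profiler  # module added by this commit (see below)

profiler = Profiler()  # writes the '{"traceEvents": ' prefix and starts the FileWriter thread

with profiler.record_event("external", "toy_request", args={"batch_size": 4}, util={"util": 4}):
    with profiler.record_event("internal", "toy_step"):
        time.sleep(0.05)  # placeholder for real work such as generate_token()

time.sleep(2)  # the background FileWriter flushes to server_events.json roughly once per second
```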
server/text_generation_server/profiler.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import os
import threading
import queue

from contextlib import contextmanager
import time
import json

class FileWriter(threading.Thread):
    def __init__(self, filename, write_queue):
        super().__init__()
        self.filename = filename
        self.write_queue = write_queue
        self.daemon = True
        self.timer_event = threading.Event()

    def _drain_write_queue(self):
        content = ""
        while True:
            try:
                element = self.write_queue.get_nowait()
                content += element
            except queue.Empty:
                break
        return content

    def run(self):
        # don't check the queue too often
        while not self.timer_event.wait(1):
            # Block and wait for the next item in the queue
            content = self.write_queue.get()
            # Collect any other items in the queue
            content += self._drain_write_queue()

            with open(self.filename, "a") as outfile:
                outfile.write(content)

class Profiler():
    profiling_trace_events = queue.Queue()
    event_tid = {"counter": 1, "external": 2, "internal": 3, "own": 4}
    filename = "server_events.json"

    def __init__(self):
        self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
        if self.enabled:
            # initialize the trace file
            with open(self.filename, "w") as outfile:
                outfile.write('{"traceEvents": ')
            file_writer = FileWriter(self.filename, self.profiling_trace_events)
            file_writer.start()

    @contextmanager
    def record_event(self, type, name, args=None, util=None):
        if self.enabled:
            start = time.time() * 1000000.0
            if util is not None:
                self.profiling_trace_events.put(json.dumps([{
                    "pid": 1,
                    "tid": self.event_tid["counter"],
                    "ph": "C",
                    "name": "util",
                    "ts": start,
                    "args": {
                        "util": util["util"],
                    }}]))

            event = {
                "pid": 1,
                "tid": self.event_tid[type],
                "ph": "X",
                "name": name,
                "ts": start,
                "dur": None,
                "args": args
            }
            yield

            end = time.time() * 1000000.0
            event["dur"] = end - start

            self.profiling_trace_events.put(json.dumps([event]))
        else:
            yield
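Note that `FileWriter` appends each `json.dumps([...])` chunk after the initial `'{"traceEvents": '` prefix, so the raw `server_events.json` ends up as a stream of concatenated arrays rather than a single JSON document. A hedged post-processing sketch (not part of this commit; the output filename is invented) that folds it into one Chrome-trace-format file loadable in `chrome://tracing` or Perfetto:

```python
import json

PREFIX = '{"traceEvents": '

with open("server_events.json") as f:             # default filename from the Profiler class
    body = f.read()[len(PREFIX):]                 # drop the unterminated prefix

events, decoder, idx = [], json.JSONDecoder(), 0
while idx < len(body):
    if body[idx].isspace():                       # tolerate any stray whitespace between chunks
        idx += 1
        continue
    chunk, idx = decoder.raw_decode(body, idx)    # parse one '[{...}]' array at a time
    events.extend(chunk)

with open("server_events_fixed.json", "w") as f:  # hypothetical output name
    json.dump({"traceEvents": events}, f)
```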
server/text_generation_server/server.py
@@ -16,9 +16,12 @@ from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor

+from .profiler import Profiler

class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
+        self.profiler = Profiler()
+        with self.profiler.record_event("external", "init"):
            self.cache = cache
            self.model = model
            self.server_urls = server_urls
@@ -41,6 +44,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
+        with self.profiler.record_event("external", "clear_cache"):
            if request.HasField("id"):
                self.cache.delete(request.id)
            else:
@@ -49,6 +53,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
+        with self.profiler.record_event("external",
+                                        "filter_batch",
+                                        {"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
+                                        {"util": len(batch.requests)}):
            if batch is None:
                raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
            filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
@@ -57,6 +65,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
+        with self.profiler.record_event("external", "warmup"):
            # batch = self.model.batch_type.from_pb(
            #     request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            # )
@@ -72,7 +81,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        batch = self.model.batch_type.from_pb(
            request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
        )
+        with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):

+            with self.profiler.record_event("internal", "generate_token"):
                generations, next_batch = self.model.generate_token(batch)
            self.cache.set(next_batch)
@@ -82,6 +93,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        )

    async def Decode(self, request, context):
+        batch0 = self.cache.cache[request.batches[0].id]
+        with self.profiler.record_event("external",
+                                        "decode",
+                                        {"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.input_ids.size(0)},
+                                        {"util": len(batch0.requests)}):
            if len(request.batches) == 0:
                raise ValueError("Must provide at least one batch")
@@ -96,10 +112,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                raise ValueError("All batches are empty")

            if len(batches) > 1:
+                with self.profiler.record_event("internal", "concatenate"):
                    batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
            else:
                batch = batches[0]

+            with self.profiler.record_event("internal", "generate_token"):
                generations, next_batch = self.model.generate_token(batch)
            self.cache.set(next_batch)
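As a follow-up (again only a sketch, reusing the hypothetical `server_events_fixed.json` produced by the repair snippet above), the recorded complete events can be summarized per name, which covers the spans added in this diff such as `prefill`, `decode`, `generate_token`, and `concatenate`:

```python
import json
from collections import defaultdict

with open("server_events_fixed.json") as f:   # repaired trace from the earlier sketch
    trace = json.load(f)

totals = defaultdict(float)
for event in trace["traceEvents"]:
    if event.get("ph") == "X":                # 'X' = complete event; 'dur' is in microseconds
        totals[event["name"]] += event["dur"]

for name, dur_us in sorted(totals.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:<20} {dur_us / 1000.0:10.3f} ms")
```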