High-level server profiler (#13)

Krzysztof Laskowski 2024-01-16 09:57:29 +01:00 committed by GitHub
parent 41c4f4fa41
commit c459c86f88
3 changed files with 156 additions and 54 deletions

README.md

@@ -79,6 +79,7 @@ Environment Variables Added:
| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command |
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
</div>

text_generation_server/profiler.py (new file)

@@ -0,0 +1,83 @@
import os
import threading
import queue
from contextlib import contextmanager
import time
import json


class FileWriter(threading.Thread):
    def __init__(self, filename, write_queue):
        super().__init__()
        self.filename = filename
        self.write_queue = write_queue
        self.daemon = True
        self.timer_event = threading.Event()

    def _drain_write_queue(self):
        content = ""
        while True:
            try:
                element = self.write_queue.get_nowait()
                content += element
            except queue.Empty:
                break
        return content

    def run(self):
        # don't check the queue too often
        while not self.timer_event.wait(1):
            # Block and wait for the next item in the queue
            content = self.write_queue.get()
            # Collect any other items in the queue
            content += self._drain_write_queue()
            with open(self.filename, "a") as outfile:
                outfile.write(content)


class Profiler():
    profiling_trace_events = queue.Queue()
    event_tid = {"counter": 1, "external": 2, "internal": 3, "own": 4}
    filename = "server_events.json"

    def __init__(self):
        self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
        if self.enabled:
            # initialize the trace file
            with open(self.filename, "w") as outfile:
                outfile.write('{"traceEvents": ')
            file_writer = FileWriter(self.filename, self.profiling_trace_events)
            file_writer.start()

    @contextmanager
    def record_event(self, type, name, args=None, util=None):
        if self.enabled:
            start = time.time() * 1000000.0
            if util is not None:
                self.profiling_trace_events.put(json.dumps([{
                    "pid": 1,
                    "tid": self.event_tid["counter"],
                    "ph": "C",
                    "name": "util",
                    "ts": start,
                    "args": {
                        "util": util["util"],
                    }}]))

            event = {
                "pid": 1,
                "tid": self.event_tid[type],
                "ph": "X",
                "name": name,
                "ts": start,
                "dur": None,
                "args": args
            }

            yield

            end = time.time() * 1000000.0
            event["dur"] = end - start

            self.profiling_trace_events.put(json.dumps([event]))
        else:
            yield
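
A minimal usage sketch (illustrative, not part of this commit): the environment variable has to be set before Profiler() is constructed, only rank 0 records events, and queued spans are flushed to server_events.json by the background FileWriter thread. The import path below is an assumption for a standalone script.

# Illustrative only: drives the Profiler above outside of the gRPC server.
import os
import time

os.environ["TGI_PROFILER_ENABLED"] = "true"   # must be set before Profiler() is created
os.environ.setdefault("RANK", "0")            # events are recorded on rank 0 only

from profiler import Profiler                 # assumed import path for a standalone run

profiler = Profiler()

# Queues one "util" counter sample on entry and one complete ("X") span named
# "toy_step" on exit of the with-block.
with profiler.record_event("external", "toy_step", args={"batch_size": 1}, util={"util": 4}):
    time.sleep(0.05)  # stand-in for real work

# FileWriter flushes roughly once per second and is a daemon thread, so give it
# time to write before the interpreter exits; otherwise queued events are lost.
time.sleep(2)
print(open("server_events.json").read()[:120])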

text_generation_server/server.py

@@ -16,18 +16,21 @@ from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
+from .profiler import Profiler


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
-        self.cache = cache
-        self.model = model
-        self.server_urls = server_urls
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
-        # op not optimized issue. Will investigate further.
-        # if model.device.type == "hpu":
-        # Force inference mode for the lifetime of TextGenerationService
-        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        self.profiler = Profiler()
+        with self.profiler.record_event("external", "init"):
+            self.cache = cache
+            self.model = model
+            self.server_urls = server_urls
+            # For some reason, inference_mode does not work well with GLOO which we use on CPU
+            # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
+            # op not optimized issue. Will investigate further.
+            # if model.device.type == "hpu":
+            # Force inference mode for the lifetime of TextGenerationService
+            # self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info
@@ -41,72 +44,87 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
-        if request.HasField("id"):
-            self.cache.delete(request.id)
-        else:
-            self.cache.clear()
-        return generate_pb2.ClearCacheResponse()
+        with self.profiler.record_event("external", "clear_cache"):
+            if request.HasField("id"):
+                self.cache.delete(request.id)
+            else:
+                self.cache.clear()
+            return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
-        if batch is None:
-            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
-        self.cache.set(filtered_batch)
+        with self.profiler.record_event("external",
+                                        "filter_batch",
+                                        {"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
+                                        {"util": len(batch.requests)}):
+            if batch is None:
+                raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+            filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
+            self.cache.set(filtered_batch)

-        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+            return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
-        # batch = self.model.batch_type.from_pb(
-        #     request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-        # )
-        # max_supported_total_tokens = self.model.warmup(batch)
+        with self.profiler.record_event("external", "warmup"):
+            # batch = self.model.batch_type.from_pb(
+            #     request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+            # )
+            # max_supported_total_tokens = self.model.warmup(batch)

-        # return generate_pb2.WarmupResponse(
-        #     max_supported_total_tokens=max_supported_total_tokens
-        # )
-        logger.warning("Warmup is not enabled on HPU.")
-        return generate_pb2.WarmupResponse()
+            # return generate_pb2.WarmupResponse(
+            #     max_supported_total_tokens=max_supported_total_tokens
+            # )
+            logger.warning("Warmup is not enabled on HPU.")
+            return generate_pb2.WarmupResponse()

    async def Prefill(self, request, context):
        batch = self.model.batch_type.from_pb(
            request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
        )

-        generations, next_batch = self.model.generate_token(batch)
-        self.cache.set(next_batch)
+        with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):
+            with self.profiler.record_event("internal", "generate_token"):
+                generations, next_batch = self.model.generate_token(batch)
+            self.cache.set(next_batch)

-        return generate_pb2.PrefillResponse(
-            generations=[generation.to_pb() for generation in generations],
-            batch=next_batch.to_pb() if next_batch else None,
-        )
+            return generate_pb2.PrefillResponse(
+                generations=[generation.to_pb() for generation in generations],
+                batch=next_batch.to_pb() if next_batch else None,
+            )

    async def Decode(self, request, context):
-        if len(request.batches) == 0:
-            raise ValueError("Must provide at least one batch")
+        batch0 = self.cache.cache[request.batches[0].id]
+        with self.profiler.record_event("external",
+                                        "decode",
+                                        {"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.input_ids.size(0)},
+                                        {"util": len(batch0.requests)}):
+            if len(request.batches) == 0:
+                raise ValueError("Must provide at least one batch")

-        batches = []
-        for batch_pb in request.batches:
-            batch = self.cache.pop(batch_pb.id)
-            if batch is None:
-                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-            batches.append(batch)
+            batches = []
+            for batch_pb in request.batches:
+                batch = self.cache.pop(batch_pb.id)
+                if batch is None:
+                    raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
+                batches.append(batch)

-        if len(batches) == 0:
-            raise ValueError("All batches are empty")
+            if len(batches) == 0:
+                raise ValueError("All batches are empty")

-        if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
-        else:
-            batch = batches[0]
+            if len(batches) > 1:
+                with self.profiler.record_event("internal", "concatenate"):
+                    batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
+            else:
+                batch = batches[0]

-        generations, next_batch = self.model.generate_token(batch)
-        self.cache.set(next_batch)
+            with self.profiler.record_event("internal", "generate_token"):
+                generations, next_batch = self.model.generate_token(batch)
+            self.cache.set(next_batch)

-        return generate_pb2.DecodeResponse(
-            generations=[generation.to_pb() for generation in generations],
-            batch=next_batch.to_pb() if next_batch else None,
-        )
+            return generate_pb2.DecodeResponse(
+                generations=[generation.to_pb() for generation in generations],
+                batch=next_batch.to_pb() if next_batch else None,
+            )

def serve(
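
The spans recorded above ("external" handler events, "internal" generate_token/concatenate events, and the "util" counter samples) all land in server_events.json. Since the profiler opens the file with the literal prefix '{"traceEvents": ' and then appends one JSON list per flush, the raw file is not a closed JSON document. A small post-processing sketch (illustrative, not part of this commit; the output file name and the normalization step are assumptions) that merges the flushed lists into a single trace object:

import json

PREFIX = '{"traceEvents": '


def load_server_events(path="server_events.json"):
    """Merge the profiler's concatenated JSON lists into one trace dict."""
    with open(path) as f:
        raw = f.read()
    body = raw[len(PREFIX):] if raw.startswith(PREFIX) else raw

    decoder = json.JSONDecoder()
    events, pos = [], 0
    while pos < len(body):
        # Each flush appended one JSON list; decode them one at a time.
        chunk, pos = decoder.raw_decode(body, pos)
        events.extend(chunk)
        while pos < len(body) and body[pos].isspace():
            pos += 1
    return {"traceEvents": events}


if __name__ == "__main__":
    with open("server_events_normalized.json", "w") as f:
        json.dump(load_server_events(), f)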