Overhead reduction (#58) (#85)

Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
Karol Damaszke authored on 2024-02-29 09:17:45 +01:00 (committed via GitHub)
Parent: 212136dff8
Commit: 022ce1eaaf
4 changed files with 65 additions and 172 deletions

Changed file 1 of 4

@@ -957,9 +957,12 @@ class CausalLM(Model):
             new_input_length = input_length + 1

             # Generated token
-            next_token_text, prefix_offset, read_offset = self.decode_token(
-                all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
-            )
+            if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
+                next_token_text = ''
+            else:
+                next_token_text, prefix_offset, read_offset = self.decode_token(
+                    all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
+                )

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -975,9 +978,12 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
-                    )
+                    if is_tokenizer_transparent(self.tokenizer):
+                        output_text = None
+                    else:
+                        output_text = self.decode(
+                            all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
+                        )
                     generated_text = GeneratedText(
                         output_text,
                         stopping_criteria.current_tokens,
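
Both hunks above gate detokenization on an is_tokenizer_transparent() helper that this commit does not define. As a minimal sketch of the assumed contract (every name below is a hypothetical stand-in, not the repository's actual helper):

# Hypothetical sketch: a "transparent" tokenizer is assumed to be a pass-through used when the
# caller only works with token ids, so per-token and final detokenization can be skipped.
class TransparentTokenizer:
    def decode(self, token_ids, **kwargs):
        return ""  # no text is ever consumed by the client

def is_tokenizer_transparent(tokenizer) -> bool:
    return isinstance(tokenizer, TransparentTokenizer)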
@@ -1034,7 +1040,8 @@ class CausalLM(Model):
             req.input_length = new_input_length
             req.prefix_offset = prefix_offset
             req.read_offset = read_offset
-            htorch.core.mark_step()
+
+        htorch.core.mark_step()
         self.step = self.step + 1
         if self.hb_profiler is not None:
             if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
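
Reading the last hunk as moving htorch.core.mark_step() out of the per-request loop, the saving comes from Gaudi's lazy execution mode: queued ops are only compiled and launched when mark_step() is called, so one flush per decode step replaces one per request. A rough sketch (assumes a Gaudi host with habana_frameworks installed; the tensors are arbitrary):

import habana_frameworks.torch as htorch
import torch

x = torch.randn(8, 8, device="hpu")
for _ in range(4):
    x = x @ x                # queued lazily, not executed yet
htorch.core.mark_step()      # compile and launch the accumulated ops as a single graph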

Changed file 2 of 4 (file removed)

@@ -1,94 +0,0 @@
-import os
-import threading
-import queue
-from contextlib import contextmanager
-import time
-import json
-
-
-class FileWriter(threading.Thread):
-    def __init__(self, filename, write_queue):
-        super().__init__()
-        self.filename = filename
-        self.write_queue = write_queue
-        self.daemon = True
-        self.timer_event = threading.Event()
-
-    def _drain_write_queue(self):
-        content = ""
-        while True:
-            try:
-                element = self.write_queue.get_nowait()
-                content += element
-            except queue.Empty:
-                break
-        return content
-
-    def run(self):
-        # don't check the queue too often
-        while not self.timer_event.wait(1):
-            # Block and wait for the next item in the queue
-            content = self.write_queue.get()
-            # Collect any other items in the queue
-            content += self._drain_write_queue()
-            with open(self.filename, "a") as outfile:
-                outfile.write(content)
-
-
-class Profiler():
-    profiling_trace_events = queue.Queue()
-    event_tid = {"counter": 1, "external": 2, "internal": 3, "own": 4}
-    filename = "server_events.json"
-
-    def __init__(self):
-        self.step = 0
-        self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
-        if self.enabled:
-            # initialize the trace file
-            with open(self.filename, "w") as outfile:
-                outfile.write('{"traceEvents": ')
-            file_writer = FileWriter(self.filename, self.profiling_trace_events)
-            file_writer.start()
-
-    @contextmanager
-    def record_event(self, type, name, args=None, util=None, count_step=False):
-        if self.enabled:
-            start = time.time() * 1000000.0
-            if util is not None:
-                self._add_util_event(util, start)
-            if count_step:
-                if args is None:
-                    args = {}
-                args["step"] = self.step
-                self.step += 1
-
-            event = {
-                "pid": 1,
-                "tid": self.event_tid[type],
-                "ph": "X",
-                "name": name,
-                "ts": start,
-                "dur": None,
-                "args": args
-            }
-            yield
-            end = time.time() * 1000000.0
-            event["dur"] = end - start
-            self.profiling_trace_events.put(json.dumps([event]))
-        else:
-            yield
-
-    def _add_util_event(self, util, start):
-        util_event = {
-            "pid": 1,
-            "tid": self.event_tid["counter"],
-            "ph": "C",
-            "name": "util",
-            "ts": start,
-            "args": {
-                "util": util["util"],
-            }
-        }
-        self.profiling_trace_events.put(json.dumps([util_event]))
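
For reference, a minimal usage sketch of the Profiler removed above, mirroring the call sites deleted from the gRPC servicer below; run_prefill is a placeholder for the profiled work, and profiling is only active with TGI_PROFILER_ENABLED=true on rank 0:

profiler = Profiler()
with profiler.record_event(type="external", name="prefill", args={"batch_size": 4}):
    with profiler.record_event(type="internal", name="generate_token", count_step=True):
        run_prefill()  # placeholder for the work being timed

# Each block appends a Chrome trace "X" (duration) event to server_events.json;
# the util argument additionally emits a "C" (counter) event.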

Changed file 3 of 4

@@ -16,21 +16,18 @@ from text_generation_server.models import Model, get_model
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from .profiler import Profiler


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
-        self.profiler = Profiler()
-        with self.profiler.record_event("external", "init"):
-            self.cache = cache
-            self.model = model
-            self.server_urls = server_urls
-            # For some reason, inference_mode does not work well with GLOO which we use on CPU
-            # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
-            # op not optimized issue. Will investigate further.
-            # if model.device.type == "hpu":
-            # Force inference mode for the lifetime of TextGenerationService
-            # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        self.cache = cache
+        self.model = model
+        self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
+        # op not optimized issue. Will investigate further.
+        # if model.device.type == "hpu":
+        # Force inference mode for the lifetime of TextGenerationService
+        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def Info(self, request, context):
         return self.model.info
@@ -44,27 +41,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

     async def ClearCache(self, request, context):
-        with self.profiler.record_event("external", "clear_cache"):
-            if request.HasField("id"):
-                self.cache.delete(request.id)
-            else:
-                self.cache.clear()
-            return generate_pb2.ClearCacheResponse()
+        if request.HasField("id"):
+            self.cache.delete(request.id)
+        else:
+            self.cache.clear()
+        return generate_pb2.ClearCacheResponse()

     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
-        with self.profiler.record_event(
-            type="external",
-            name="filter_batch",
-            args={"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
-            util={"util": len(batch.requests)}
-        ):
-            if batch is None:
-                raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-            filtered_batch = batch.filter(request.request_ids)
-            self.cache.set(filtered_batch)
+        if batch is None:
+            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+        filtered_batch = batch.filter(request.request_ids)
+        self.cache.set(filtered_batch)
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
         def batch_from_pb(batch):
@@ -72,59 +62,44 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
             )

-        with self.profiler.record_event("external", "warmup"):
-            batches = [batch_from_pb(batch) for batch in request.batches]
-            self.model.warmup(batches)
+        batches = [batch_from_pb(batch) for batch in request.batches]
+        self.model.warmup(batches)

         return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
         )
-        with self.profiler.record_event(
-            type="external",
-            name="prefill",
-            args={"batch_size": batch.batch_size, "sequence_length": batch.seq_length}
-        ):
-            with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
-                generations, next_batch = self.model.generate_token([batch])
-            self.cache.set(next_batch)
+        generations, next_batch = self.model.generate_token([batch])
+        self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )

     async def Decode(self, request, context):
-        batch0 = self.cache.cache[request.batches[0].id]
-        with self.profiler.record_event(
-            type="external",
-            name="decode",
-            args={"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.batch_size},
-            util={"util": len(batch0.requests)}
-        ):
-            if len(request.batches) == 0:
-                raise ValueError("Must provide at least one batch")
+        if len(request.batches) == 0:
+            raise ValueError("Must provide at least one batch")

         batches = []
         for batch_pb in request.batches:
             batch = self.cache.pop(batch_pb.id)
             if batch is None:
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
             batches.append(batch)

         if len(batches) == 0:
             raise ValueError("All batches are empty")

-        with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
-            generations, next_batch = self.model.generate_token(batches)
-            self.cache.set(next_batch)
+        generations, next_batch = self.model.generate_token(batches)
+        self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )


 def serve(

Changed file 4 of 4

@@ -212,8 +212,13 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
-        logprobs = torch.log_softmax(scores, -1)
-        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+        # ignore logprobs if we use greedy search
+        if type(self.choice) == Greedy:
+            logprobs = torch.zeros_like(scores, device="cpu")
+            next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
+        else:
+            logprobs = torch.log_softmax(scores, -1)
+            next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

         return next_ids, next_logprobs, logprobs
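
The new Greedy branch returns placeholder log-probabilities instead of computing a log-softmax over the full vocabulary, which is the per-step cost being avoided. A standalone sketch of both paths with synthetic tensors (shapes assumed for illustration):

import torch

batch_size, vocab_size = 4, 32000
scores = torch.randn(batch_size, vocab_size)
next_ids = torch.argmax(scores, dim=-1)  # what a Greedy chooser effectively returns

# Greedy path: skip the vocabulary-wide log-softmax and return zero placeholders.
greedy_logprobs = torch.zeros_like(scores, device="cpu")
greedy_next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")

# Sampling path: real log-probabilities are still computed.
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)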