diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 17364e829..41129da05 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -957,9 +957,12 @@ class CausalLM(Model):
             new_input_length = input_length + 1

             # Generated token
-            next_token_text, prefix_offset, read_offset = self.decode_token(
-                all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
-            )
+            if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
+                next_token_text = ''
+            else:
+                next_token_text, prefix_offset, read_offset = self.decode_token(
+                    all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
+                )

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -975,9 +978,12 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
-                    )
+                    if is_tokenizer_transparent(self.tokenizer):
+                        output_text = None
+                    else:
+                        output_text = self.decode(
+                            all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
+                        )
                     generated_text = GeneratedText(
                         output_text,
                         stopping_criteria.current_tokens,
@@ -1034,7 +1040,8 @@ class CausalLM(Model):
             req.input_length = new_input_length
             req.prefix_offset = prefix_offset
             req.read_offset = read_offset
-            htorch.core.mark_step()
+
+        htorch.core.mark_step()
         self.step = self.step + 1
         if self.hb_profiler is not None:
             if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
diff --git a/server/text_generation_server/profiler.py b/server/text_generation_server/profiler.py
deleted file mode 100644
index e0248932d..000000000
--- a/server/text_generation_server/profiler.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import os
-import threading
-import queue
-
-from contextlib import contextmanager
-import time
-import json
-
-class FileWriter(threading.Thread):
-    def __init__(self, filename, write_queue):
-        super().__init__()
-        self.filename = filename
-        self.write_queue = write_queue
-        self.daemon = True
-        self.timer_event = threading.Event()
-
-    def _drain_write_queue(self):
-        content = ""
-        while True:
-            try:
-                element = self.write_queue.get_nowait()
-                content += element
-            except queue.Empty:
-                break
-        return content
-
-    def run(self):
-        # don't check the queue too often
-        while not self.timer_event.wait(1):
-            # Block and wait for the next item in the queue
-            content = self.write_queue.get()
-            # Collect any other items in the queue
-            content += self._drain_write_queue()
-
-            with open(self.filename, "a") as outfile:
-                outfile.write(content)
-
-class Profiler():
-    profiling_trace_events = queue.Queue()
-    event_tid = {"counter": 1, "external": 2, "internal": 3, "own": 4}
-    filename = "server_events.json"
-
-    def __init__(self):
-        self.step = 0
-        self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
-        if self.enabled:
-            # initialize the trace file
-            with open(self.filename, "w") as outfile:
-                outfile.write('{"traceEvents": ')
-            file_writer = FileWriter(self.filename, self.profiling_trace_events)
-            file_writer.start()
-
-    @contextmanager
-    def record_event(self, type, name, args=None, util=None, count_step=False):
-        if self.enabled:
-            start = time.time() * 1000000.0
-            if util is not None:
-                self._add_util_event(util, start)
-
-            if count_step:
-                if args is None:
-                    args = {}
-                args["step"] = self.step
-                self.step += 1
-            event = {
-                "pid": 1,
-                "tid": self.event_tid[type],
-                "ph": "X",
-                "name": name,
-                "ts": start,
-                "dur": None,
-                "args": args
-            }
-            yield
-
-            end = time.time() * 1000000.0
-            event["dur"] = end - start
-
-            self.profiling_trace_events.put(json.dumps([event]))
-        else:
-            yield
-
-    def _add_util_event(self, util, start):
-        util_event = {
-            "pid": 1,
-            "tid": self.event_tid["counter"],
-            "ph": "C",
-            "name": "util",
-            "ts": start,
-            "args": {
-                "util": util["util"],
-            }
-        }
-        self.profiling_trace_events.put(json.dumps([util_event]))
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 841bda933..32b1b9142 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -16,21 +16,18 @@ from text_generation_server.models import Model, get_model
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from .profiler import Profiler


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
-        self.profiler = Profiler()
-        with self.profiler.record_event("external", "init"):
-            self.cache = cache
-            self.model = model
-            self.server_urls = server_urls
-            # For some reason, inference_mode does not work well with GLOO which we use on CPU
-            # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
-            # op not optimized issue. Will investigate further.
-            # if model.device.type == "hpu":
-            # Force inference mode for the lifetime of TextGenerationService
-            # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        self.cache = cache
+        self.model = model
+        self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
+        # op not optimized issue. Will investigate further.
+        # if model.device.type == "hpu":
+        # Force inference mode for the lifetime of TextGenerationService
+        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def Info(self, request, context):
         return self.model.info
@@ -44,27 +41,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

     async def ClearCache(self, request, context):
-        with self.profiler.record_event("external", "clear_cache"):
-            if request.HasField("id"):
-                self.cache.delete(request.id)
-            else:
-                self.cache.clear()
-            return generate_pb2.ClearCacheResponse()
+        if request.HasField("id"):
+            self.cache.delete(request.id)
+        else:
+            self.cache.clear()
+        return generate_pb2.ClearCacheResponse()

     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
-        with self.profiler.record_event(
-            type="external",
-            name="filter_batch",
-            args={"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
-            util={"util": len(batch.requests)}
-        ):
-            if batch is None:
-                raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-            filtered_batch = batch.filter(request.request_ids)
-            self.cache.set(filtered_batch)
+        if batch is None:
+            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+        filtered_batch = batch.filter(request.request_ids)
+        self.cache.set(filtered_batch)

-            return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
         def batch_from_pb(batch):
@@ -72,59 +62,44 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
             )

-        with self.profiler.record_event("external", "warmup"):
-            batches = [batch_from_pb(batch) for batch in request.batches]
-            self.model.warmup(batches)
+        batches = [batch_from_pb(batch) for batch in request.batches]
+        self.model.warmup(batches)

-            return generate_pb2.WarmupResponse()
+        return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
         )
-        with self.profiler.record_event(
-            type="external",
-            name="prefill",
-            args={"batch_size": batch.batch_size, "sequence_length": batch.seq_length}
-        ):
-            with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
-                generations, next_batch = self.model.generate_token([batch])
-            self.cache.set(next_batch)
+        generations, next_batch = self.model.generate_token([batch])
+        self.cache.set(next_batch)

-            return generate_pb2.PrefillResponse(
-                generations=[generation.to_pb() for generation in generations],
-                batch=next_batch.to_pb() if next_batch else None,
-            )
+        return generate_pb2.PrefillResponse(
+            generations=[generation.to_pb() for generation in generations],
+            batch=next_batch.to_pb() if next_batch else None,
+        )

     async def Decode(self, request, context):
-        batch0 = self.cache.cache[request.batches[0].id]
-        with self.profiler.record_event(
-            type="external",
-            name="decode",
-            args={"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.batch_size},
-            util={"util": len(batch0.requests)}
-        ):
-            if len(request.batches) == 0:
-                raise ValueError("Must provide at least one batch")
+        if len(request.batches) == 0:
+            raise ValueError("Must provide at least one batch")

-            batches = []
-            for batch_pb in request.batches:
-                batch = self.cache.pop(batch_pb.id)
-                if batch is None:
-                    raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-                batches.append(batch)
+        batches = []
+        for batch_pb in request.batches:
+            batch = self.cache.pop(batch_pb.id)
+            if batch is None:
+                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
+            batches.append(batch)

-            if len(batches) == 0:
-                raise ValueError("All batches are empty")
+        if len(batches) == 0:
+            raise ValueError("All batches are empty")

-            with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
-                generations, next_batch = self.model.generate_token(batches)
-            self.cache.set(next_batch)
+        generations, next_batch = self.model.generate_token(batches)
+        self.cache.set(next_batch)

-            return generate_pb2.DecodeResponse(
-                generations=[generation.to_pb() for generation in generations],
-                batch=next_batch.to_pb() if next_batch else None,
-            )
+        return generate_pb2.DecodeResponse(
+            generations=[generation.to_pb() for generation in generations],
+            batch=next_batch.to_pb() if next_batch else None,
+        )


 def serve(
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 2814ced8f..2535f464c 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -212,8 +212,13 @@ class HeterogeneousNextTokenChooser:
                 scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
-        logprobs = torch.log_softmax(scores, -1)
-        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+        # ignore logprobs if we use greedy search
+        if type(self.choice) == Greedy:
+            logprobs = torch.zeros_like(scores, device="cpu")
+            next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
+        else:
+            logprobs = torch.log_softmax(scores, -1)
+            next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

         return next_ids, next_logprobs, logprobs
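Note on the tokens.py hunk: greedy selection is a plain argmax over the logits, and log_softmax only shifts each row by a constant, so skipping it cannot change which token is picked. Below is a minimal, self-contained sketch of that short-circuit in plain PyTorch, detached from the TGI HeterogeneousNextTokenChooser; the function name choose_next_tokens, the boolean greedy flag (standing in for the patch's type(self.choice) == Greedy check), and the tensor shapes are illustrative assumptions, not code from the repo.

import torch


def choose_next_tokens(scores: torch.Tensor, greedy: bool):
    """Sketch of the patched branch: skip logprob computation for greedy requests.

    scores: [batch_size, vocab_size] raw logits.
    greedy: stands in for the patch's `type(self.choice) == Greedy` check.
    """
    # Greedy choice is an argmax over the raw logits; log_softmax only shifts
    # each row by a constant, so it cannot change which index wins.
    next_ids = scores.argmax(dim=-1)  # the sampling path is elided in this sketch
    if greedy:
        # Cheap zero placeholders, mirroring the patch; no softmax or gather.
        logprobs = torch.zeros_like(scores, device="cpu")
        next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
    else:
        logprobs = torch.log_softmax(scores, -1)
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
    return next_ids, next_logprobs, logprobs


# Illustrative shapes only: 4 sequences, 32k-token vocabulary.
ids, next_lp, lp = choose_next_tokens(torch.randn(4, 32000), greedy=True)

One consequence visible in the diff: greedy requests now receive zero-filled logprob placeholders, so any per-token logprobs reported for greedy decoding will be zeros rather than real values.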