Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)

Commit 022ce1eaaf (parent 212136dff8)
Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
@@ -957,9 +957,12 @@ class CausalLM(Model):
             new_input_length = input_length + 1

             # Generated token
-            next_token_text, prefix_offset, read_offset = self.decode_token(
-                all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
-            )
+            if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
+                next_token_text = ''
+            else:
+                next_token_text, prefix_offset, read_offset = self.decode_token(
+                    all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
+                )

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -975,9 +978,12 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
-                    )
+                    if is_tokenizer_transparent(self.tokenizer):
+                        output_text = None
+                    else:
+                        output_text = self.decode(
+                            all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
+                        )
                     generated_text = GeneratedText(
                         output_text,
                         stopping_criteria.current_tokens,
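The two hunks above gate per-step and final detokenization on a new is_tokenizer_transparent(...) helper that is not shown in this excerpt. A minimal sketch of the apparent intent, assuming a "transparent" tokenizer is a wrapper that passes token ids through unchanged (so decoding them to text adds nothing and can be skipped when no stop sequences need matching):

# Hypothetical sketch only; the real helper and whatever marks a tokenizer as
# transparent live elsewhere in this commit.
def is_tokenizer_transparent(tokenizer) -> bool:
    # e.g. an attribute or wrapper class set when the client exchanges raw token ids
    return getattr(tokenizer, "transparent", False)

When the check passes, next_token_text is left empty during decoding and output_text is returned as None, which avoids host-side detokenization on every generated token.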
@@ -1034,7 +1040,8 @@ class CausalLM(Model):
             req.input_length = new_input_length
             req.prefix_offset = prefix_offset
             req.read_offset = read_offset
-        htorch.core.mark_step()
+
+        htorch.core.mark_step()
         self.step = self.step + 1
         if self.hb_profiler is not None:
            if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
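The step-count check above drives self.hb_profiler, which is defined outside this diff. A plausible reading, stated as an assumption, is that it wraps a torch.profiler session with a wait/warmup/active schedule built from the same three settings; a self-contained sketch of that pattern (illustrative values, not the repo's actual configuration):

# Assumption: hb_profiler behaves like a torch.profiler.profile object stepped once
# per generation step and stopped after wait + warmup + active steps.
import torch

profiling_wait_steps, profiling_warmup_steps, profiling_steps = 2, 2, 5  # illustrative

hb_profiler = torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=profiling_wait_steps, warmup=profiling_warmup_steps, active=profiling_steps, repeat=1
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("hpu_profile"),
)
hb_profiler.start()
for _ in range(profiling_wait_steps + profiling_warmup_steps + profiling_steps):
    hb_profiler.step()  # one call per generation step, matching the threshold in the hunk above
hb_profiler.stop()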
@@ -1,94 +0,0 @@
-import os
-import threading
-import queue
-
-from contextlib import contextmanager
-import time
-import json
-
-class FileWriter(threading.Thread):
-    def __init__(self, filename, write_queue):
-        super().__init__()
-        self.filename = filename
-        self.write_queue = write_queue
-        self.daemon = True
-        self.timer_event = threading.Event()
-
-    def _drain_write_queue(self):
-        content = ""
-        while True:
-            try:
-                element = self.write_queue.get_nowait()
-                content += element
-            except queue.Empty:
-                break
-        return content
-
-    def run(self):
-        # don't check the queue too often
-        while not self.timer_event.wait(1):
-            # Block and wait for the next item in the queue
-            content = self.write_queue.get()
-            # Collect any other items in the queue
-            content += self._drain_write_queue()
-
-            with open(self.filename, "a") as outfile:
-                outfile.write(content)
-
-class Profiler():
-    profiling_trace_events = queue.Queue()
-    event_tid = {"counter": 1, "external": 2, "internal": 3, "own": 4}
-    filename = "server_events.json"
-
-    def __init__(self):
-        self.step = 0
-        self.enabled = os.getenv("TGI_PROFILER_ENABLED", "false").lower() == "true" and int(os.getenv("RANK", "0")) == 0
-        if self.enabled:
-            # initialize the trace file
-            with open(self.filename, "w") as outfile:
-                outfile.write('{"traceEvents": ')
-            file_writer = FileWriter(self.filename, self.profiling_trace_events)
-            file_writer.start()
-
-    @contextmanager
-    def record_event(self, type, name, args=None, util=None, count_step=False):
-        if self.enabled:
-            start = time.time() * 1000000.0
-            if util is not None:
-                self._add_util_event(util, start)
-
-            if count_step:
-                if args is None:
-                    args = {}
-                args["step"] = self.step
-                self.step += 1
-            event = {
-                "pid": 1,
-                "tid": self.event_tid[type],
-                "ph": "X",
-                "name": name,
-                "ts": start,
-                "dur": None,
-                "args": args
-            }
-            yield
-
-            end = time.time() * 1000000.0
-            event["dur"] = end - start
-
-            self.profiling_trace_events.put(json.dumps([event]))
-        else:
-            yield
-
-    def _add_util_event(self, util, start):
-        util_event = {
-            "pid": 1,
-            "tid": self.event_tid["counter"],
-            "ph": "C",
-            "name": "util",
-            "ts": start,
-            "args": {
-                "util": util["util"],
-            }
-        }
-        self.profiling_trace_events.put(json.dumps([util_event]))
@@ -16,21 +16,18 @@ from text_generation_server.models import Model, get_model
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor

-from .profiler import Profiler

 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
-        self.profiler = Profiler()
-        with self.profiler.record_event("external", "init"):
-            self.cache = cache
-            self.model = model
-            self.server_urls = server_urls
-            # For some reason, inference_mode does not work well with GLOO which we use on CPU
-            # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
-            # op not optimized issue. Will investigate further.
-            # if model.device.type == "hpu":
-            # Force inference mode for the lifetime of TextGenerationService
-            # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        self.cache = cache
+        self.model = model
+        self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
+        # op not optimized issue. Will investigate further.
+        # if model.device.type == "hpu":
+        # Force inference mode for the lifetime of TextGenerationService
+        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def Info(self, request, context):
         return self.model.info
@@ -44,27 +41,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

     async def ClearCache(self, request, context):
-        with self.profiler.record_event("external", "clear_cache"):
-            if request.HasField("id"):
-                self.cache.delete(request.id)
-            else:
-                self.cache.clear()
-            return generate_pb2.ClearCacheResponse()
+        if request.HasField("id"):
+            self.cache.delete(request.id)
+        else:
+            self.cache.clear()
+        return generate_pb2.ClearCacheResponse()

     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
-        with self.profiler.record_event(
-            type="external",
-            name="filter_batch",
-            args={"batch_id": request.batch_id, "request_ids": [id for id in request.request_ids]},
-            util={"util": len(batch.requests)}
-        ):
-            if batch is None:
-                raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-            filtered_batch = batch.filter(request.request_ids)
-            self.cache.set(filtered_batch)
+        if batch is None:
+            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+        filtered_batch = batch.filter(request.request_ids)
+        self.cache.set(filtered_batch)

         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
         def batch_from_pb(batch):
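Besides dropping the instrumentation, the new FilterBatch ordering is also safer: in the removed version, util={"util": len(batch.requests)} was evaluated while building the record_event call, i.e. before the batch-is-None check ran, so a missing batch id failed with an AttributeError instead of the intended ValueError. A self-contained illustration of that ordering problem (not part of the commit):

cache = {}                                # stands in for the server cache
batch = cache.pop("missing-id", None)     # may return None
try:
    util = {"util": len(batch.requests)}  # AttributeError raised here when batch is None...
    if batch is None:
        raise ValueError("Batch ID missing-id not found in cache.")  # ...so this was never reached
except AttributeError as err:
    print(err)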
@@ -72,59 +62,44 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
             )

-        with self.profiler.record_event("external", "warmup"):
-            batches = [batch_from_pb(batch) for batch in request.batches]
-            self.model.warmup(batches)
+        batches = [batch_from_pb(batch) for batch in request.batches]
+        self.model.warmup(batches)

         return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
         )
-        with self.profiler.record_event(
-            type="external",
-            name="prefill",
-            args={"batch_size": batch.batch_size, "sequence_length": batch.seq_length}
-        ):
-            with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
-                generations, next_batch = self.model.generate_token([batch])
-            self.cache.set(next_batch)
+        generations, next_batch = self.model.generate_token([batch])
+        self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )

     async def Decode(self, request, context):
-        batch0 = self.cache.cache[request.batches[0].id]
-        with self.profiler.record_event(
-            type="external",
-            name="decode",
-            args={"request_batches": [batch.id for batch in request.batches], "batch_size": batch0.batch_size},
-            util={"util": len(batch0.requests)}
-        ):
-            if len(request.batches) == 0:
-                raise ValueError("Must provide at least one batch")
+        if len(request.batches) == 0:
+            raise ValueError("Must provide at least one batch")

         batches = []
         for batch_pb in request.batches:
             batch = self.cache.pop(batch_pb.id)
             if batch is None:
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
             batches.append(batch)

         if len(batches) == 0:
             raise ValueError("All batches are empty")

-        with self.profiler.record_event(type="internal", name="generate_token", count_step=True):
-            generations, next_batch = self.model.generate_token(batches)
-        self.cache.set(next_batch)
+        generations, next_batch = self.model.generate_token(batches)
+        self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )


 def serve(
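The removed Decode prologue had a similar ordering quirk: batch0 = self.cache.cache[request.batches[0].id] ran before the empty-batches check, so an empty request raised IndexError rather than the intended "Must provide at least one batch" ValueError; the simplified version checks emptiness first. Reduced to a self-contained illustration (not part of the commit):

batches = []                  # stands in for request.batches
try:
    first = batches[0]        # IndexError raised here on an empty request...
    if len(batches) == 0:
        raise ValueError("Must provide at least one batch")  # ...so this was never reached
except IndexError as err:
    print(err)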
@@ -212,8 +212,13 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
-        logprobs = torch.log_softmax(scores, -1)
-        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+        # ignore logprobs if we use greedy search
+        if type(self.choice) == Greedy:
+            logprobs = torch.zeros_like(scores, device="cpu")
+            next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
+        else:
+            logprobs = torch.log_softmax(scores, -1)
+            next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

         return next_ids, next_logprobs, logprobs

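The greedy shortcut works because log_softmax is monotonic: taking the argmax of the raw scores selects the same token as taking it over the log-probabilities, so the log-probabilities are simply not computed and zero tensors are returned instead. A small self-contained check of that invariance (not part of the commit):

import torch

scores = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.3, 0.2]])
assert torch.equal(
    scores.argmax(dim=-1),
    torch.log_softmax(scores, -1).argmax(dim=-1),
)  # log_softmax preserves ordering, so the greedy choice is unchanged

One behavioural difference worth noting: in the greedy branch the returned tensors are zeros allocated on the CPU, while the sampling branch keeps them on the device of scores, so callers that treat these values as real log-probabilities (or assume a particular device) see different results on the two paths.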