text-generation-inference/server/text_generation_server/server.py
fxmarty 166dc0b87d MI300 compatibility (#1764)
Adds support for AMD Instinct MI300 in TGI.

Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
3489ce7936
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1

By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) in order
to avoid to have to rerun the tuning at each `docker run`.

Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```

---------

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
2024-07-17 05:36:58 +00:00

227 lines
7.4 KiB
Python

# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import asyncio
import os
import sys
import torch
import time
import signal
from grpc import aio
from loguru import logger
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
class SignalHandler:
KEEP_PROCESSING = True
def __init__(self):
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)
def exit_gracefully(self, signum, frame):
print(f"Exiting gracefully: Signal {signum}")
self.KEEP_PROCESSING = False
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(
self,
model: Model,
cache: Cache,
server_urls: List[str],
):
self.cache = cache
self.model = model
self.server_urls = server_urls
# For some reason, inference_mode does not work well with GLOO which we use on CPU
# TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
# op not optimized issue. Will investigate further.
# if model.device.type == "hpu":
# Force inference mode for the lifetime of TextGenerationService
# self._inference_mode_raii_guard = torch._C._InferenceMode(True)
async def Info(self, request, context):
return self.model.info
async def Health(self, request, context):
if self.model.device.type == "hpu":
torch.zeros((2, 2)).to("hpu")
return generate_pb2.HealthResponse()
async def ServiceDiscovery(self, request, context):
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
async def ClearCache(self, request, context):
if request.HasField("id"):
self.cache.delete(request.id)
else:
self.cache.clear()
return generate_pb2.ClearCacheResponse()
async def FilterBatch(self, request, context):
batch = self.cache.pop(request.batch_id)
if batch is None:
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
filtered_batch = batch.filter(request.request_ids)
self.cache.set(filtered_batch)
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
def batch_from_pb(batch):
return self.model.batch_type.from_pb(
batch, self.model.tokenizer, self.model.dtype, self.model.device
)
batches = [batch_from_pb(batch) for batch in request.batches]
self.model.warmup(batches)
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context):
start = time.time_ns()
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch, timings = self.model.generate_token([batch])
self.cache.set(next_batch)
return generate_pb2.PrefillResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
)
async def Decode(self, request, context):
start = time.time_ns()
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")
batches = []
for batch_pb in request.batches:
batch = self.cache.pop(batch_pb.id)
if batch is None:
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
batches.append(batch)
if len(batches) == 0:
raise ValueError("All batches are empty")
generations, next_batch, timings = self.model.generate_token(batches)
self.cache.set(next_batch)
return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,
concat_ns=None, # TODO: measure concat time
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
)
def serve(
model_id: str,
revision: Optional[str],
sharded: bool,
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level="INFO",
serialize=False,
backtrace=True,
diagnose=False,
)
async def serve_inner(
model_id: str,
revision: Optional[str],
sharded: bool = False,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
logger.info("Server:server_inner: sharded ={}".format(sharded))
if sharded:
rank = int(os.environ["RANK"])
logger.info("Server:server_inner: rank ={}".format(rank))
server_urls = [
unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["RANK"])]
else:
local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url]
logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url))
if dtype == "bfloat16" or None:
data_type = torch.bfloat16
else:
data_type = torch.float
if revision == "None":
revision = None
try:
model = get_model(
model_id,
revision,
speculate,
data_type,
trust_remote_code
)
except Exception:
logger.exception("Error when initializing model")
raise
server = aio.server(
interceptors=[
ExceptionInterceptor(),
UDSOpenTelemetryAioServerInterceptor(),
]
)
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls), server
)
SERVICE_NAMES = (
generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
reflection.SERVICE_NAME,
)
reflection.enable_server_reflection(SERVICE_NAMES, server)
server.add_insecure_port(local_url)
await server.start()
logger.info("Server started at {}".format(local_url))
signal_handler = SignalHandler()
while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5)
set_model_id(model_id)
asyncio.run(
serve_inner(
model_id, revision, sharded, speculate, dtype, trust_remote_code
)
)