From a6057c407650bfdb026b2712b460edc293183447 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 24 Jul 2023 10:41:24 +0200 Subject: [PATCH] fix(server): fix exllama buffers --- server/text_generation_server/server.py | 30 +++++++++---------- .../utils/gptq/exllama.py | 8 +++++ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 0929b46f..ab99ac55 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( - model_id: str, - revision: Optional[str], - sharded: bool, - quantize: Optional[str], - dtype: Optional[str], - trust_remote_code: bool, - uds_path: Path, -): - async def serve_inner( model_id: str, revision: Optional[str], - sharded: bool = False, - quantize: Optional[str] = None, - dtype: Optional[str] = None, - trust_remote_code: bool = False, + sharded: bool, + quantize: Optional[str], + dtype: Optional[str], + trust_remote_code: bool, + uds_path: Path, +): + async def serve_inner( + model_id: str, + revision: Optional[str], + sharded: bool = False, + quantize: Optional[str] = None, + dtype: Optional[str] = None, + trust_remote_code: bool = False, ): unix_socket_template = "unix://{}-{}" if sharded: @@ -146,9 +146,9 @@ def serve( # For which we have the finale shapes only after the model has loaded # This will allocate those buffers. from text_generation_server.utils.gptq.exllama import ( - create_exllama_buffers, + create_exllama_buffers, set_device ) - + set_device(model.device) create_exllama_buffers() except ImportError: pass diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py index e89b725c..64941409 100644 --- a/server/text_generation_server/utils/gptq/exllama.py +++ b/server/text_generation_server/utils/gptq/exllama.py @@ -12,6 +12,7 @@ def ext_make_q4(qweight, qzeros, scales, g_idx, device): ) + def ext_q4_matmul(x, q4, q4_width): """Matrix multiplication, returns x @ q4""" outshape = x.shape[:-1] + (q4_width,) @@ -32,9 +33,16 @@ TEMP_STATE = None TEMP_DQ = None +def set_device(device): + global DEVICE + DEVICE = device + + def create_exllama_buffers(): global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ + assert DEVICE is not None, "call set_device first" + if ACT_ORDER: # TODO: this should be set to rust side `max_total_tokens`, but TGI # does not offer an API to expose this variable to python, as this variable