fix(server): fix exllama buffers

This commit is contained in:
OlivierDehaene 2023-07-24 10:41:24 +02:00
parent 73a4d65d26
commit a6057c4076
2 changed files with 23 additions and 15 deletions

View File

@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
async def serve_inner(
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
async def serve_inner(
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@ -146,9 +146,9 @@ def serve(
# For which we have the finale shapes only after the model has loaded
# This will allocate those buffers.
from text_generation_server.utils.gptq.exllama import (
create_exllama_buffers,
create_exllama_buffers, set_device
)
set_device(model.device)
create_exllama_buffers()
except ImportError:
pass

View File

@ -12,6 +12,7 @@ def ext_make_q4(qweight, qzeros, scales, g_idx, device):
)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
@ -32,9 +33,16 @@ TEMP_STATE = None
TEMP_DQ = None
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers():
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
assert DEVICE is not None, "call set_device first"
if ACT_ORDER:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable