mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
fix(server): fix exllama buffers
This commit is contained in:
parent
73a4d65d26
commit
a6057c4076
@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
|
||||
|
||||
def serve(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
sharded: bool,
|
||||
quantize: Optional[str],
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
):
|
||||
async def serve_inner(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
sharded: bool = False,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
sharded: bool,
|
||||
quantize: Optional[str],
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
):
|
||||
async def serve_inner(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
sharded: bool = False,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
if sharded:
|
||||
@ -146,9 +146,9 @@ def serve(
|
||||
# For which we have the finale shapes only after the model has loaded
|
||||
# This will allocate those buffers.
|
||||
from text_generation_server.utils.gptq.exllama import (
|
||||
create_exllama_buffers,
|
||||
create_exllama_buffers, set_device
|
||||
)
|
||||
|
||||
set_device(model.device)
|
||||
create_exllama_buffers()
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -12,6 +12,7 @@ def ext_make_q4(qweight, qzeros, scales, g_idx, device):
|
||||
)
|
||||
|
||||
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
@ -32,9 +33,16 @@ TEMP_STATE = None
|
||||
TEMP_DQ = None
|
||||
|
||||
|
||||
def set_device(device):
|
||||
global DEVICE
|
||||
DEVICE = device
|
||||
|
||||
|
||||
def create_exllama_buffers():
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
||||
|
||||
assert DEVICE is not None, "call set_device first"
|
||||
|
||||
if ACT_ORDER:
|
||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||
# does not offer an API to expose this variable to python, as this variable
|
||||
|
Loading…
Reference in New Issue
Block a user