mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
parent
73a4d65d26
commit
37df6df38e
@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||||||
|
|
||||||
|
|
||||||
def serve(
|
def serve(
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str],
|
|
||||||
sharded: bool,
|
|
||||||
quantize: Optional[str],
|
|
||||||
dtype: Optional[str],
|
|
||||||
trust_remote_code: bool,
|
|
||||||
uds_path: Path,
|
|
||||||
):
|
|
||||||
async def serve_inner(
|
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
sharded: bool = False,
|
sharded: bool,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str],
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str],
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool,
|
||||||
|
uds_path: Path,
|
||||||
|
):
|
||||||
|
async def serve_inner(
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str],
|
||||||
|
sharded: bool = False,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
unix_socket_template = "unix://{}-{}"
|
unix_socket_template = "unix://{}-{}"
|
||||||
if sharded:
|
if sharded:
|
||||||
@ -147,8 +147,10 @@ def serve(
|
|||||||
# This will allocate those buffers.
|
# This will allocate those buffers.
|
||||||
from text_generation_server.utils.gptq.exllama import (
|
from text_generation_server.utils.gptq.exllama import (
|
||||||
create_exllama_buffers,
|
create_exllama_buffers,
|
||||||
|
set_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
set_device(model.device)
|
||||||
create_exllama_buffers()
|
create_exllama_buffers()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
@ -32,9 +32,16 @@ TEMP_STATE = None
|
|||||||
TEMP_DQ = None
|
TEMP_DQ = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_device(device):
|
||||||
|
global DEVICE
|
||||||
|
DEVICE = device
|
||||||
|
|
||||||
|
|
||||||
def create_exllama_buffers():
|
def create_exllama_buffers():
|
||||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
||||||
|
|
||||||
|
assert DEVICE is not None, "call set_device first"
|
||||||
|
|
||||||
if ACT_ORDER:
|
if ACT_ORDER:
|
||||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||||
# does not offer an API to expose this variable to python, as this variable
|
# does not offer an API to expose this variable to python, as this variable
|
||||||
|
Loading…
Reference in New Issue
Block a user