Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Run server run-dev and plumb through adapter-id
parent 9f18f4c006
commit 034e39185f
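The commit threads a new adapter_id argument from the Typer CLI through server.serve and get_model into the FlashLlama constructor. As a reading aid, here is a minimal, stubbed sketch of that call chain; the bodies are placeholders, not the real implementations:

from typing import Optional


class FlashLlama:
    # Stub standing in for the real FlashLlama model class touched by this diff.
    def __init__(self, model_id: str, adapter_id: str, revision: Optional[str] = None):
        self.model_id = model_id
        self.adapter_id = adapter_id
        self.revision = revision


def get_model(model_id: str, adapter_id: str, revision: Optional[str]) -> FlashLlama:
    # The real get_model warns that adapter_id is ignored for non-FlashLlama models.
    return FlashLlama(model_id, adapter_id, revision)


def serve(model_id: str, adapter_id: str = "", revision: Optional[str] = None) -> FlashLlama:
    # The real server.serve wraps this in an async serve_inner and a gRPC server.
    return get_model(model_id, adapter_id, revision)


if __name__ == "__main__":
    model = serve("meta-llama/Llama-2-7b-hf", adapter_id="arnavgrg/codealpaca-qlora")
    print(model.adapter_id)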
@@ -24,7 +24,7 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve meta-llama/Llama-2-7b-hf --adapter-id arnavgrg/codealpaca-qlora --sharded
 
 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
@@ -24,6 +24,7 @@ class Dtype(str, Enum):
 @app.command()
 def serve(
     model_id: str,
+    adapter_id: str = "",
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
@@ -76,7 +77,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
     )
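For context, and assuming the standard Typer behavior this CLI relies on, declaring adapter_id as a keyword parameter of the serve command exposes it as a --adapter-id option with an empty-string default, which is what the updated run-dev target passes. A minimal standalone sketch:

import typer

app = typer.Typer()


@app.command()
def serve(model_id: str, adapter_id: str = ""):
    # Typer maps the adapter_id keyword argument to a --adapter-id option,
    # defaulting to an empty string when the flag is not supplied.
    print(f"model_id={model_id} adapter_id={adapter_id!r}")


if __name__ == "__main__":
    app()
    # e.g. python sketch.py meta-llama/Llama-2-7b-hf --adapter-id arnavgrg/codealpaca-qlora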
@@ -68,12 +68,19 @@ if FLASH_ATTENTION:
 
 def get_model(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if len(adapter_id) > 0:
+        logger.warning(
+            "adapter_id is only supported for FlashLlama models and will be "
+            "ignored for other models."
+        )
+
     if dtype is None:
         dtype = torch.float16
     elif dtype == "float16":
@@ -184,6 +191,7 @@ def get_model(
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
+                adapter_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
@@ -23,11 +23,21 @@ class FlashLlama(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
+        adapter_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        print("ASDFASDF FLASHLLAMA INIT")
+        print("Args:")
+        print(f"model_id: {model_id}")
+        print(f"adapter_id: {adapter_id}")
+        print(f"revision: {revision}")
+        print(f"quantize: {quantize}")
+        print(f"dtype: {dtype}")
+        print(f"trust_remote_code: {trust_remote_code}")
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -106,6 +106,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
 def serve(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -115,6 +116,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
+        adapter_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -134,7 +136,7 @@ def serve(
 
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code
            )
         except Exception:
             logger.exception("Error when initializing model")
@@ -182,5 +184,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code)
     )
@@ -16,6 +16,9 @@ class Weights:
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
     ):
+        # idea: maybe we can pass in adapter filenames here and have these take
+        # precedence over the model filenames? If so, then self.routing would
+        # just handle the mapping of tensor names to filenames.
         routing = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
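The comment added above sketches an idea rather than implemented behavior. A minimal sketch of how it could look, assuming a hypothetical adapter_filenames argument that is not part of this diff: processing adapter files after the base model files lets their tensors overwrite shared entries in routing, so adapter weights would take precedence.

from typing import Dict, List

from safetensors import safe_open


def build_routing(filenames: List[str], adapter_filenames: List[str]) -> Dict[str, str]:
    # Hypothetical sketch: map each tensor name to the file that provides it.
    # Adapter files are processed last, so their tensors override the base
    # model's entries for any shared names.
    routing: Dict[str, str] = {}
    for filename in list(filenames) + list(adapter_filenames):
        with safe_open(filename, framework="pytorch") as f:
            for tensor_name in f.keys():
                routing[tensor_name] = filename
    return routing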