Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
run server run-dev and plumbs through adapter-id
This commit is contained in:
parent: 9f18f4c006
commit: 034e39185f
@@ -24,7 +24,7 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"

 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve meta-llama/Llama-2-7b-hf --adapter-id arnavgrg/codealpaca-qlora --sharded

 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
@@ -24,6 +24,7 @@ class Dtype(str, Enum):
 @app.command()
 def serve(
     model_id: str,
+    adapter_id: str = "",
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
@@ -76,7 +77,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
     )

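
A minimal, self-contained sketch (not taken from the repo) of how Typer derives the --adapter-id flag used in the run-dev target above from the new keyword argument; the empty-string default means "no adapter":

# sketch.py — hypothetical standalone example, not part of text-generation-inference
import typer

app = typer.Typer()

@app.command()
def serve(model_id: str, adapter_id: str = ""):
    # Typer exposes adapter_id as the optional --adapter-id flag; an empty
    # string is treated as "no adapter requested".
    typer.echo(f"model={model_id} adapter={adapter_id or '<none>'}")

if __name__ == "__main__":
    app()

Invoked as `python sketch.py meta-llama/Llama-2-7b-hf --adapter-id arnavgrg/codealpaca-qlora`, this mirrors the updated run-dev command.
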
@@ -68,12 +68,19 @@ if FLASH_ATTENTION:

 def get_model(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if len(adapter_id) > 0:
+        logger.warning(
+            "adapter_id is only supported for FlashLlama models and will be "
+            "ignored for other models."
+        )
+
     if dtype is None:
         dtype = torch.float16
     elif dtype == "float16":
@@ -184,6 +191,7 @@ def get_model(
     if FLASH_ATTENTION:
         return FlashLlama(
             model_id,
+            adapter_id,
             revision,
             quantize=quantize,
             dtype=dtype,
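
The warning fires for any non-empty adapter_id because, in this commit, only the FlashLlama branch shown above forwards the new argument; every other model constructor keeps its existing argument list. A simplified sketch of that dispatch pattern (the model names and return values here are placeholders, not the repo's real signatures):

from typing import Optional

def dispatch(model_type: str, model_id: str, adapter_id: str, revision: Optional[str]):
    # Only the llama branch threads adapter_id through; other branches keep
    # their existing argument lists, so the adapter is silently ignored there.
    if model_type == "llama":
        return ("FlashLlama", model_id, adapter_id, revision)
    return ("OtherModel", model_id, revision)
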
@@ -23,11 +23,21 @@ class FlashLlama(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
+        adapter_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        print("ASDFASDF FLASHLLAMA INIT")
+        print("Args:")
+        print(f"model_id: {model_id}")
+        print(f"adapter_id: {adapter_id}")
+        print(f"revision: {revision}")
+        print(f"quantize: {quantize}")
+        print(f"dtype: {dtype}")
+        print(f"trust_remote_code: {trust_remote_code}")
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
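
The print statements above read as temporary debugging output. One possible cleanup, an assumption rather than part of this commit, is to route them through the logger already used in the get_model hunk (assumed here to be loguru), so the output respects the server's configured log level:

from loguru import logger

def log_init_args(**kwargs) -> None:
    # e.g. log_init_args(model_id=model_id, adapter_id=adapter_id, revision=revision)
    for name, value in kwargs.items():
        logger.debug("FlashLlama init arg {}={}", name, value)
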
@@ -106,6 +106,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

 def serve(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -115,6 +116,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
+        adapter_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -134,7 +136,7 @@ def serve(

         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -182,5 +184,5 @@ def serve(
         await server.stop(0)

     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code)
     )
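
Taken together, the CLI, get_model, and server changes apply one plumbing pattern: every layer that already forwards model_id now forwards adapter_id next to it, so the value given on the command line reaches the model constructor unchanged. A compressed, runnable sketch of that flow (function names and the dict stand-in for the model are illustrative only):

import asyncio
from typing import Optional

def build_model(model_id: str, adapter_id: str, revision: Optional[str]) -> dict:
    # Stand-in for get_model / FlashLlama: just record what reached the bottom layer.
    return {"model_id": model_id, "adapter_id": adapter_id, "revision": revision}

def serve(model_id: str, adapter_id: str, revision: Optional[str] = None) -> dict:
    async def serve_inner(model_id: str, adapter_id: str, revision: Optional[str]) -> dict:
        return build_model(model_id, adapter_id, revision)

    # adapter_id rides along with model_id through every call layer.
    return asyncio.run(serve_inner(model_id, adapter_id, revision))

print(serve("meta-llama/Llama-2-7b-hf", "arnavgrg/codealpaca-qlora"))
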
@@ -16,6 +16,9 @@ class Weights:
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
     ):
+        # idea: maybe we can pass in adapter filenames here and have these take
+        # precedence over the model filenames? If so, then self.routing would
+        # just handle the mapping of tensor names to filenames.
         routing = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
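
The new comment sketches how adapter weights could plug into the Weights routing table. A hedged illustration of that idea (the helper name and the two filename lists are hypothetical; the existing loop in Weights.__init__ builds a similar tensor-name-to-filename map from a single list):

from typing import Dict, List
from safetensors import safe_open

def build_routing(model_filenames: List[str], adapter_filenames: List[str]) -> Dict[str, str]:
    routing: Dict[str, str] = {}
    # Visit adapter files last so their tensors take precedence over the base
    # model's tensors, per the idea in the comment above.
    for filename in [*model_filenames, *adapter_filenames]:
        with safe_open(filename, framework="pytorch") as f:
            for tensor_name in f.keys():
                routing[tensor_name] = filename
    return routing
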