Run server run-dev and plumb through adapter-id

Geoffrey Angus 2023-08-15 13:43:00 -07:00
parent 9f18f4c006
commit 034e39185f
6 changed files with 28 additions and 4 deletions


@@ -24,7 +24,7 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"

 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve meta-llama/Llama-2-7b-hf --adapter-id arnavgrg/codealpaca-qlora --sharded

 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
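
For context, arnavgrg/codealpaca-qlora is a LoRA adapter hosted on the Hugging Face Hub, used here on top of meta-llama/Llama-2-7b-hf. Outside this server, the same adapter can be applied with the peft library; this is a minimal sketch for orientation only, not code from this repository, and it assumes the adapter follows the standard PEFT layout:

    # Sketch: apply the Hub-hosted LoRA adapter to the base model with peft.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    # Wraps the base model so generation uses the base weights plus the LoRA deltas.
    model = PeftModel.from_pretrained(base, "arnavgrg/codealpaca-qlora")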


@@ -24,6 +24,7 @@ class Dtype(str, Enum):
 @app.command()
 def serve(
     model_id: str,
+    adapter_id: str = "",
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
@@ -76,7 +77,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
     )
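
Because the command is defined with typer, the new adapter_id parameter with its empty-string default is exposed as an optional --adapter-id flag. A minimal standalone sketch of the same pattern (illustrative only, not the repository's actual cli.py):

    # Sketch: a typer command where a defaulted str parameter becomes a CLI option.
    import typer

    app = typer.Typer()

    @app.callback()
    def main():
        """Sketch CLI; the callback keeps `serve` as a named subcommand."""

    @app.command()
    def serve(
        model_id: str,          # required positional argument
        adapter_id: str = "",   # exposed as --adapter-id; empty string means "no adapter"
        sharded: bool = False,  # exposed as --sharded / --no-sharded
    ):
        typer.echo(f"model={model_id} adapter={adapter_id or '<none>'} sharded={sharded}")

    if __name__ == "__main__":
        app()

Invoked as, e.g., python sketch_cli.py serve meta-llama/Llama-2-7b-hf --adapter-id arnavgrg/codealpaca-qlora --sharded, mirroring the run-dev target above.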


@@ -68,12 +68,19 @@ if FLASH_ATTENTION:
 def get_model(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if len(adapter_id) > 0:
+        logger.warning(
+            "adapter_id is only supported for FlashLlama models and will be "
+            "ignored for other models."
+        )
     if dtype is None:
         dtype = torch.float16
     elif dtype == "float16":
@@ -184,6 +191,7 @@ def get_model(
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
+                adapter_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,


@@ -23,11 +23,21 @@ class FlashLlama(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
+        adapter_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        print("ASDFASDF FLASHLLAMA INIT")
+        print("Args:")
+        print(f"model_id: {model_id}")
+        print(f"adapter_id: {adapter_id}")
+        print(f"revision: {revision}")
+        print(f"quantize: {quantize}")
+        print(f"dtype: {dtype}")
+        print(f"trust_remote_code: {trust_remote_code}")
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")


@@ -106,6 +106,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -115,6 +116,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
+        adapter_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -134,7 +136,7 @@ def serve(
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -182,5 +184,5 @@ def serve(
         await server.stop(0)
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code)
     )
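
Taken together with the earlier hunks, adapter_id now travels positionally from the CLI through server.serve and serve_inner into get_model. A condensed, runnable sketch of that call chain (heavily simplified; the real functions take many more parameters and start a gRPC server):

    # Condensed sketch of how adapter_id flows after this change; stand-ins only.
    import asyncio

    def get_model(model_id: str, adapter_id: str) -> str:
        # Stand-in for models/__init__.py: only FlashLlama actually consumes adapter_id.
        return f"FlashLlama({model_id}, adapter={adapter_id or '<none>'})"

    def serve(model_id: str, adapter_id: str) -> None:
        # Stand-in for server.py: adapter_id is forwarded into the async inner function.
        async def serve_inner(model_id: str, adapter_id: str) -> None:
            model = get_model(model_id, adapter_id)
            print("serving", model)

        asyncio.run(serve_inner(model_id, adapter_id))

    # cli.py calls server.serve(...) with adapter_id in the second position.
    serve("meta-llama/Llama-2-7b-hf", "arnavgrg/codealpaca-qlora")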


@@ -16,6 +16,9 @@ class Weights:
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
     ):
+        # idea: maybe we can pass in adapter filenames here and have these take
+        # precedence over the model filenames? If so, then self.routing would
+        # just handle the mapping of tensor names to filenames.
         routing = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
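
The new comment describes an idea rather than implemented behavior: adapter safetensors files could be added alongside the model filenames and win any tensor-name collision, so that routing alone decides which file supplies each tensor. A minimal sketch of that precedence rule, using the same safe_open call as weights.py (the build_routing helper is hypothetical and not part of this commit):

    # Sketch: build a tensor-name -> filename routing table where adapter
    # files take precedence over base-model files on name collisions.
    from safetensors import safe_open

    def build_routing(model_filenames, adapter_filenames=()):
        routing = {}
        # Base-model files first...
        for filename in model_filenames:
            with safe_open(filename, framework="pytorch") as f:
                for name in f.keys():
                    routing[name] = filename
        # ...then adapter files, overwriting any duplicate tensor names.
        for filename in adapter_filenames:
            with safe_open(filename, framework="pytorch") as f:
                for name in f.keys():
                    routing[name] = filename
        return routing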