diff --git a/server/Makefile b/server/Makefile
index a4ce6d8b..ddd98692 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -24,7 +24,7 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve meta-llama/Llama-2-7b-hf --adapter-id arnavgrg/codealpaca-qlora --sharded
 
 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e74c0331..08446ac1 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -24,6 +24,7 @@ class Dtype(str, Enum):
 @app.command()
 def serve(
     model_id: str,
+    adapter_id: str = "",
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
@@ -76,7 +77,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
     )
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e9260eed..10d9ebf7 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -68,12 +68,19 @@ if FLASH_ATTENTION:
 def get_model(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if len(adapter_id) > 0:
+        logger.warning(
+            "adapter_id is only supported for FlashLlama models and will be "
+            "ignored for other models."
+        )
+
     if dtype is None:
         dtype = torch.float16
     elif dtype == "float16":
@@ -184,6 +191,7 @@ def get_model(
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
+                adapter_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 96fb0c26..01d1ca6a 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -23,11 +23,21 @@ class FlashLlama(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
+        adapter_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        print("ASDFASDF FLASHLLAMA INIT")
+        print("Args:")
+        print(f"model_id: {model_id}")
+        print(f"adapter_id: {adapter_id}")
+        print(f"revision: {revision}")
+        print(f"quantize: {quantize}")
+        print(f"dtype: {dtype}")
+        print(f"trust_remote_code: {trust_remote_code}")
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 1cedc151..53a602b1 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -106,6 +106,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_id: str,
+    adapter_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@@ -115,6 +116,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
+        adapter_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
@@ -134,7 +136,7 @@ def serve(
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -182,5 +184,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code)
     )
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 0330402d..d4f5f0b6 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -16,6 +16,9 @@ class Weights:
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
     ):
+        # idea: maybe we can pass in adapter filenames here and have these take
+        # precedence over the model filenames? If so, then self.routing would
+        # just handle the mapping of tensor names to filenames.
        routing = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
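
Note: the comment added to Weights.__init__ above only records an idea; it is not implemented in this patch. A minimal sketch of how that routing precedence could look, assuming a hypothetical adapter_filenames argument passed alongside the base model filenames:

# Sketch only, not part of the diff above. `adapter_filenames` is a
# hypothetical extra argument; adapter files are merged into the routing
# dict after the base model files, so any tensor name present in both
# resolves to the adapter file (adapter takes precedence).
from typing import Dict, List

from safetensors import safe_open


def build_routing(model_filenames: List[str], adapter_filenames: List[str]) -> Dict[str, str]:
    routing: Dict[str, str] = {}
    for filename in model_filenames + adapter_filenames:
        with safe_open(filename, framework="pytorch") as f:
            for name in f.keys():
                # later files win, so adapter tensors override model tensors
                routing[name] = filename
    return routing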