From d6d2b9426bd7ab06d1dcf75c503262367c282176 Mon Sep 17 00:00:00 2001
From: Subaandh Krishnakumar
Date: Tue, 29 Aug 2023 11:06:57 +0200
Subject: [PATCH] feat: support for loading peft model

---
 launcher/src/main.rs                          | 14 ++++
 server/text_generation_server/cli.py          |  4 +-
 .../text_generation_server/models/__init__.py |  8 +++
 .../models/causal_lm.py                       | 67 ++++++++++++++-----
 .../text_generation_server/models/flash_rw.py | 51 +++++++++++---
 server/text_generation_server/models/rw.py    | 66 +++++++++++++-----
 server/text_generation_server/server.py       |  8 ++-
 server/text_generation_server/utils/hub.py    |  1 +
 server/text_generation_server/utils/peft.py   |  1 -
 9 files changed, 172 insertions(+), 48 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index cbb6f25d..a75fffe2 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -100,6 +100,13 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,

+    /// The name of the peft model to load.
+    /// Can be a PEFT_MODEL_ID as listed on the Hugging Face Hub,
+    /// or it can be a local directory containing the necessary files (adapter_model.bin)
+    /// as saved by the `save_pretrained(...)` method of transformers.
+    #[clap(long, env)]
+    peft_model_id: Option<String>,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -347,6 +354,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_id: String,
+    peft_model_id: Option<String>,
     revision: Option<String>,
     quantize: Option<Quantization>,
     dtype: Option<Dtype>,
@@ -391,6 +399,10 @@ fn shard_manager(
         "--json-output".to_string(),
     ];

+    if let Some(peft_model_id) = peft_model_id {
+        shard_args.push("--peft-model-id".to_string());
+        shard_args.push(peft_model_id)
+    }
     // Activate trust remote code
     if trust_remote_code {
         shard_args.push("--trust-remote-code".to_string());
@@ -846,6 +858,7 @@ fn spawn_shards(
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
+        let peft_model_id = args.peft_model_id.clone();
         let revision = args.revision.clone();
         let uds_path = args.shard_uds_path.clone();
         let master_addr = args.master_addr.clone();
@@ -868,6 +881,7 @@ fn spawn_shards(
         thread::spawn(move || {
             shard_manager(
                 model_id,
+                peft_model_id,
                 revision,
                 quantize,
                 dtype,
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e3fda07f..62f0dd54 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -27,6 +27,7 @@ class Dtype(str, Enum):
 @app.command()
 def serve(
     model_id: str,
+    peft_model_id: str = None,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
@@ -78,8 +79,9 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
+    logger.info(f"{peft_model_id=}")
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_id
     )


diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 932ab32e..68fc98df 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -75,6 +75,7 @@ def get_model(
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
+    peft_model_id: str,
 ) -> Model:
     if dtype is None:
         dtype = torch.float16
@@ -85,6 +86,7 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

+    logger.info(f"In get models {peft_model_id=}")
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
             model_id,
@@ -120,6 +122,7 @@ def get_model(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
     model_type = config_dict["model_type"]
+    logger.info(f"Model type is {model_type=}")

     if model_type == "gpt_bigcode":
         if FLASH_ATTENTION:
@@ -217,20 +220,24 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
+                logger.error("In Flash RW")
                 return FlashRWSharded(
                     model_id,
                     revision,
                     quantize=quantize,
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
+                    peft_model_id=peft_model_id,
                 )
             else:
+                logger.error(f"Value not none {peft_model_id=}")
                 return RW(
                     model_id,
                     revision,
                     quantize=quantize,
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
+                    peft_model_id=peft_model_id
                 )

     elif model_type == "opt":
@@ -273,6 +280,7 @@ def get_model(
             "4bit quantization is not supported for AutoModel"
         )
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        logger.error("In CausalLM")
         return CausalLM(
             model_id,
             revision,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 4e338263..b2b2b54d 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -17,6 +17,8 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from transformers import BitsAndBytesConfig
+from peft import PeftModel

 tracer = trace.get_tracer(__name__)

@@ -483,6 +485,7 @@ class CausalLM(Model):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -494,23 +497,53 @@ class CausalLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32

-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-        )
+        if peft_model_id:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                padding_side="left",
+            )
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                llm_int8_enable_fp32_cpu_offload=True,
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                device_map="auto",
+                quantization_config=bnb_config
+            )
+
+            # load Lora weights
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_id,
+                device_map="auto",
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                revision=revision,
+                torch_dtype=dtype,
+                device_map="auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None,
+                load_in_8bit=quantize == "bitsandbytes",
+                trust_remote_code=trust_remote_code,
+            )
+
+
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 2fc7c53d..bd6a07e8 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -2,7 +2,7 @@ import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
@@ -15,6 +15,8 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from transformers import BitsAndBytesConfig
+from peft import PeftModel

 tracer = trace.get_tracer(__name__)

@@ -27,6 +29,7 @@ class FlashRWSharded(FlashCausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -35,14 +38,6 @@ class FlashRWSharded(FlashCausalLM):
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
         config = RWConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
@@ -61,7 +56,43 @@ class FlashRWSharded(FlashCausalLM):
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)

-        model = FlashRWForCausalLM(config, weights)
+        if peft_model_id:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                padding_side="left",
+            )
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                llm_int8_enable_fp32_cpu_offload=True,
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                device_map="auto",
+                quantization_config=bnb_config
+            )
+
+            # load Lora weights
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_id,
+                device_map="auto",
+            )
+            model.eval()
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+
+            model = FlashRWForCausalLM(config, weights)

         torch.distributed.barrier(group=self.process_group)
         super(FlashRWSharded, self).__init__(
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index d97c1c73..d5e3ef07 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -1,6 +1,8 @@
 import torch

 from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import BitsAndBytesConfig
+from peft import PeftModel
 from typing import List, Optional, Tuple

 from text_generation_server.models import CausalLM
@@ -14,6 +16,7 @@ class RW(CausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -25,23 +28,52 @@ class RW(CausalLM):
             device = torch.device("cpu")
             dtype = torch.float32

-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-        )
+        if peft_model_id:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                padding_side="left",
+            )
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                llm_int8_enable_fp32_cpu_offload=True,
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                device_map="auto",
+                quantization_config=bnb_config
+            )
+
+            assert 1
+            # load Lora weights
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_id,
+                device_map="auto",
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                revision=revision,
+                torch_dtype=dtype,
+                device_map="auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None,
+                load_in_8bit=quantize == "bitsandbytes",
+                trust_remote_code=trust_remote_code,
+            )

         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 67137aaa..0eb30d65 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -123,7 +123,9 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    peft_model_id: str
 ):
+    logger.info(f"Receiving peft {peft_model_id=}")
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
@@ -131,6 +133,7 @@ def serve(
         quantize: Optional[str] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
@@ -144,8 +147,9 @@ def serve(
             server_urls = [local_url]

         try:
+            logger.info(f"In server {peft_model_id=}")
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -193,5 +197,5 @@ def serve(
             await server.stop(0)

     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id)
     )
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 23743c9b..d9b18abc 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -86,6 +86,7 @@ def weight_files(
     # Local model
     if Path(model_id).exists() and Path(model_id).is_dir():
         local_files = list(Path(model_id).glob(f"*{extension}"))
+        logger.info(f"Local files {local_files=}")
         if not local_files:
             raise FileNotFoundError(
                 f"No local weights found in {model_id} with extension {extension}"
diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py
index be1f9444..7ca03b41 100644
--- a/server/text_generation_server/utils/peft.py
+++ b/server/text_generation_server/utils/peft.py
@@ -31,7 +31,6 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     logger.info(f"Merging the lora weights.")

     base_model_id = model.peft_config["default"].base_model_name_or_path

-    model = model.merge_and_unload()

     os.makedirs(model_id, exist_ok=True)
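
For reference, the loading path this patch wires into `CausalLM`, `RW`, and `FlashRWSharded` reduces to the minimal standalone sketch below: the base model is loaded in 4-bit NF4 through bitsandbytes, then the LoRA adapter is attached with `PeftModel.from_pretrained`. The model and adapter IDs are placeholders, and `transformers`, `peft`, and `bitsandbytes` are assumed to be installed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_model_id = "tiiuae/falcon-7b"      # placeholder base model
adapter_id = "my-org/falcon-7b-lora"    # placeholder adapter repo or local directory

tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side="left")

# Quantize the base model to 4-bit NF4 at load time, as the patched code paths do.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True,
)

# Attach the LoRA weights on top of the quantized base model.
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()

When launching through `text-generation-launcher`, this path is selected by the new `--peft-model-id` flag, which the launcher forwards to each shard as `--peft-model-id`.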