diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index b57c652b..c72d31d3 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -217,7 +217,5 @@ fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
     // Truncate to sequence_length
     encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
     // Decode
-    tokenizer
-        .decode(Vec::from(encoding.get_ids()), false)
-        .unwrap()
+    tokenizer.decode(encoding.get_ids(), false).unwrap()
 }
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 2ad788a4..0162a73e 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -72,6 +72,14 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,

+    /// The name of the adapter to load.
+    /// Can be a MODEL_ID as listed on <https://hf.co/models>
+    /// or it can be a local directory containing the necessary files
+    /// as saved by `save_pretrained(...)` methods of transformers.
+    /// Should be compatible with the model specified in `model_id`.
+    #[clap(default_value = "", long, env)]
+    adapter_id: String,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -290,6 +298,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_id: String,
+    adapter_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
     dtype: Option<Dtype>,
@@ -332,6 +341,12 @@ fn shard_manager(
         "--json-output".to_string(),
     ];
+    // Check if adapter id is non-empty string
+    if !adapter_id.is_empty() {
+        shard_args.push("--adapter-id".to_string());
+        shard_args.push(adapter_id);
+    }
+
     // Activate trust remote code
     if trust_remote_code {
         shard_args.push("--trust-remote-code".to_string());
     }
@@ -639,13 +654,13 @@ enum LauncherError {
     WebserverCannotStart,
 }

-fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(model_id: String, args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();

     let mut download_args = vec![
         "download-weights".to_string(),
-        args.model_id.to_string(),
+        model_id,
         "--extension".to_string(),
         ".safetensors".to_string(),
         "--logger-level".to_string(),
@@ -767,6 +782,7 @@ fn spawn_shards(
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
+        let adapter_id = args.adapter_id.clone();
         let revision = args.revision.clone();
         let uds_path = args.shard_uds_path.clone();
         let master_addr = args.master_addr.clone();
@@ -787,6 +803,7 @@ fn spawn_shards(
         thread::spawn(move || {
             shard_manager(
                 model_id,
+                adapter_id,
                 revision,
                 quantize,
                 dtype,
@@ -1081,7 +1098,13 @@ fn main() -> Result<(), LauncherError> {
     .expect("Error setting Ctrl-C handler");

     // Download and convert model weights
-    download_convert_model(&args, running.clone())?;
+    download_convert_model(args.model_id.to_string(), &args, running.clone())?;
+
+    // check if adapter_id is non-empty string
+    if !args.adapter_id.is_empty() {
+        download_convert_model(args.adapter_id.to_string(), &args, running.clone())?;
+    }
+

     if !running.load(Ordering::SeqCst) {
         // Launcher was asked to stop
diff --git a/router/src/validation.rs b/router/src/validation.rs
index be835bf0..f967361f 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -311,7 +311,7 @@ fn prepare_input(
             // truncate encoding and decode new inputs
             encoding.truncate(truncate, 0, TruncationDirection::Left);
             let inputs = tokenizer
-                .decode(Vec::from(encoding.get_ids()), false)
+                .decode(encoding.get_ids(), false)
                 .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
             (inputs, encoding.len())
         }
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 01d1ca6a..5cfa93b3 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed

+from loguru import logger
 from opentelemetry import trace
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional
@@ -11,6 +12,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     LlamaConfig,
 )
 from text_generation_server.utils import (
+    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -29,15 +31,6 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        print("ASDFASDF FLASHLLAMA INIT")
-        print("Args:")
-        print(f"model_id: {model_id}")
-        print(f"adapter_id: {adapter_id}")
-        print(f"revision: {revision}")
-        print(f"quantize: {quantize}")
-        print(f"dtype: {dtype}")
-        print(f"trust_remote_code: {trust_remote_code}")
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -70,7 +63,23 @@ class FlashLlama(FlashCausalLM):

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+
+        # get adapter filenames if adapter_id passed in
+        merged_weight_filenames = None
+        if len(adapter_id) > 0:
+            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
+            merged_weight_filenames = create_merged_weight_files(
+                adapter_id, model_id, model_weight_filenames=filenames
+            )
+
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            merged_weight_filenames=merged_weight_filenames
+        )
+
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)

diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py
index befedcf0..d443ce20 100644
--- a/server/text_generation_server/utils/__init__.py
+++ b/server/text_generation_server/utils/__init__.py
@@ -1,3 +1,4 @@
+from text_generation_server.utils.adapter import create_merged_weight_files
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
@@ -20,6 +21,7 @@ from text_generation_server.utils.tokens import (
 )

 __all__ = [
+    "create_merged_weight_files",
     "convert_file",
     "convert_files",
     "initialize_torch_distributed",
diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py
new file mode 100644
index 00000000..9f6d4960
--- /dev/null
+++ b/server/text_generation_server/utils/adapter.py
@@ -0,0 +1,127 @@
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import List, Dict, Set, Tuple
+
+import torch
+from loguru import logger
+from peft import LoraConfig
+from peft.utils import transpose
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+from text_generation_server.utils.hub import weight_files
+
+
+def compute_delta_weight(
+    lora_A: torch.Tensor,
+    lora_B: torch.Tensor,
+    fan_in_fan_out: bool,
+    alpha: float,
+    r: float
+) -> torch.Tensor:
+    """Computes the delta weight for a Linear layer given A and B LoRA matrices.
+
+    TODO: add logic for other module types beyond Linear layers.
+
+    Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806
+    """
+    scaling = alpha / r
+    delta_weight = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling
+    return delta_weight
+
+
+def merge_adapter_weights(
+    model_weights: Dict[str, torch.Tensor],
+    adapter_weights: Dict[str, torch.Tensor],
+    adapter_config: LoraConfig
+) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
+    """Merges the adapter weights into the model weights."""
+    module_mapping = defaultdict(dict)
+    processed_adapter_weight_names = set()
+
+    # map the original tensor names to their adapter counterparts
+    for weight_name in model_weights:
+        end_idx = weight_name.rfind(".weight")
+        key = weight_name[:end_idx]
+        for adapter_weight_name in adapter_weights:
+            if key in adapter_weight_name:
+                # example value: 'base_model.model.model.layers.10.self_attn.v_proj.lora_B.weight'
+                # matrix_type gets the second to last element in the module name, i.e. 'lora_B'
+                matrix_type = adapter_weight_name.split(".")[-2]
+                module_mapping[weight_name][matrix_type] = adapter_weight_name
+                processed_adapter_weight_names.add(adapter_weight_name)
+
+    # merge adapter weights into model weights
+    merged_weights = {}
+    for weight_name, adapter_weight_names in tqdm(
+        module_mapping.items(), desc="Merging adapter weights", total=len(module_mapping)):
+
+        # TODO: support adapter types beyond LoRA
+        lora_A = adapter_weights[adapter_weight_names["lora_A"]]
+        lora_B = adapter_weights[adapter_weight_names["lora_B"]]
+        delta_weight = compute_delta_weight(
+            lora_A, lora_B, adapter_config.fan_in_fan_out, adapter_config.lora_alpha, adapter_config.r)
+        merged_weights[weight_name] = model_weights[weight_name] + delta_weight
+    return merged_weights, processed_adapter_weight_names
+
+
+def create_merged_weight_files(
+    adapter_id: str,
+    model_id: str,
+    model_weight_filenames: List[Path]
+) -> List[Path]:
+    """Creates merged weight files for the given adapter ID and filenames."""
+    adapter_filenames = weight_files(adapter_id, extension=".safetensors")
+
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    if adapter_config.base_model_name_or_path != model_id:
+        raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
+
+    # load adapter weights from all shards (should have relatively small memory footprint)
+    adapter_weights = {}
+    for filename in adapter_filenames:
+        adapter_weights.update(load_file(filename))
+    remaining_adapter_weight_names = set(adapter_weights.keys())
+
+    merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
+    # just grab the existing files if they already exist and return immediately
+    if os.path.exists(merged_weight_directory):
+        logger.info("Merged weight files already exist, skipping merge computation.")
+        return weight_files(merged_weight_directory)
+
+    os.makedirs(merged_weight_directory)
+    merged_weight_filenames = []
+    for filename in model_weight_filenames:
+        model_weights = load_file(filename)
+        merged_weights, processed_adapter_weight_names = merge_adapter_weights(
+            model_weights, adapter_weights, adapter_config)
+
+        merged_adapter_filename = Path(merged_weight_directory, os.path.basename(filename))
+        save_file(merged_weights, merged_adapter_filename)
+        logger.debug(f"Saved merged weights into {merged_adapter_filename}")
+
+        merged_weight_filenames.append(merged_adapter_filename)
+        remaining_adapter_weight_names = remaining_adapter_weight_names.difference(
+            processed_adapter_weight_names)
+
+    if len(remaining_adapter_weight_names) > 0:
+        logger.warning("WARNING: The following lora weights were not merged into the model weights:")
+        for lora_name in remaining_adapter_weight_names:
+            logger.warning("\t" + lora_name)
+
+    return merged_weight_filenames
+
+
+def main():
+    adapter_id = "arnavgrg/codealpaca-qlora"
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    model_id = adapter_config.base_model_name_or_path
+    model_weight_filenames = weight_files(model_id, extension=".safetensors")
+
+    merged_adapter_filenames = create_merged_weight_files(adapter_id, model_id, model_weight_filenames)
+    print(merged_adapter_filenames)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index d4f5f0b6..d2fe57a3 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -15,19 +15,37 @@ class Weights:
         dtype,
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
+        merged_weight_filenames: Optional[List] = None,
     ):
-        # idea: maybe we can pass in adapter filenames here and have these take
-        # precedence over the model filenames? If so, then self.routing would
-        # just handle the mapping of tensor names to filenames.
+        # routes to adapter files take precedence over routes to main model files
+        # to ensure that adapter weights are loaded instead of main model weights
         routing = {}
+        if merged_weight_filenames is not None:
+            for filename in merged_weight_filenames:
+                with safe_open(filename, framework="pytorch") as f:
+                    for k in f.keys():
+                        if k in routing:
+                            raise RuntimeError(
+                                f"Key {k} was found in multiple adapter files: {filename} and {routing[k]}"
+                            )
+                        routing[k] = filename
+
+        # set of keys that point to adapter files. Duplicates for these keys found
+        # in main model files will be overridden.
+        adapter_routes = set(routing.keys())
+
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
-                    if k in routing:
+                    if k in adapter_routes:
+                        logger.debug(f"Overriding main model weights with adapter weights for key: {k}")
+                    elif k in routing:
                         raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                            f"Key {k} was found in multiple non-adapter files: {filename} and {routing[k]}"
                         )
-                    routing[k] = filename
+                    else:
+                        routing[k] = filename
+
         if aliases is None:
             aliases = {}
         self.aliases = aliases
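
Note on the merge math (not part of the patch): `compute_delta_weight` and `merge_adapter_weights` fold the LoRA update into each base weight as `W + (alpha / r) * (B @ A)`. A minimal standalone sketch of why the merged tensor reproduces the un-merged computation, assuming `fan_in_fan_out=False` (so `transpose` is a no-op) and arbitrary illustrative dimensions:

import torch

torch.manual_seed(0)
d_in, d_out, r, alpha = 16, 32, 4, 8

base = torch.randn(d_out, d_in)    # base Linear weight (out_features x in_features)
lora_A = torch.randn(r, d_in)      # LoRA A matrix
lora_B = torch.randn(d_out, r)     # LoRA B matrix
x = torch.randn(3, d_in)           # small batch of inputs

scaling = alpha / r
delta_weight = (lora_B @ lora_A) * scaling   # same formula as compute_delta_weight with fan_in_fan_out=False
merged = base + delta_weight                 # what merge_adapter_weights stores for each tensor

# output through the merged weight matches base output plus the scaled low-rank update
y_merged = x @ merged.T
y_split = x @ base.T + (x @ lora_A.T @ lora_B.T) * scaling
assert torch.allclose(y_merged, y_split, atol=1e-5)

Because the update is folded into the checkpoint ahead of time and written to the `/data/...-merged` directory, the serving path in flash_llama.py simply routes tensor loads to the merged safetensors files and incurs no per-request adapter cost.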