works end-to-end

Geoffrey Angus 2023-08-15 15:20:57 -07:00
parent 034e39185f
commit ab0937b90c
7 changed files with 200 additions and 23 deletions

View File

@@ -217,7 +217,5 @@ fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
     // Truncate to sequence_length
     encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
     // Decode
-    tokenizer
-        .decode(Vec::from(encoding.get_ids()), false)
-        .unwrap()
+    tokenizer.decode(encoding.get_ids(), false).unwrap()
 }

View File

@@ -72,6 +72,14 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,
 
+    /// The name of the adapter to load.
+    /// Can be a MODEL_ID as listed on <https://hf.co/models>
+    /// or it can be a local directory containing the necessary files
+    /// as saved by `save_pretrained(...)` methods of transformers.
+    /// Should be compatible with the model specified in `model_id`.
+    #[clap(default_value = "", long, env)]
+    adapter_id: String,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -290,6 +298,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_id: String,
+    adapter_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
     dtype: Option<Dtype>,
@@ -332,6 +341,12 @@ fn shard_manager(
         "--json-output".to_string(),
     ];
 
+    // Check if adapter id is non-empty string
+    if !adapter_id.is_empty() {
+        shard_args.push("--adapter-id".to_string());
+        shard_args.push(adapter_id);
+    }
+
     // Activate trust remote code
     if trust_remote_code {
         shard_args.push("--trust-remote-code".to_string());
@@ -639,13 +654,13 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(model_id: String, args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();
 
     let mut download_args = vec![
         "download-weights".to_string(),
-        args.model_id.to_string(),
+        model_id,
         "--extension".to_string(),
         ".safetensors".to_string(),
         "--logger-level".to_string(),
@@ -767,6 +782,7 @@ fn spawn_shards(
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
+        let adapter_id = args.adapter_id.clone();
         let revision = args.revision.clone();
         let uds_path = args.shard_uds_path.clone();
         let master_addr = args.master_addr.clone();
@@ -787,6 +803,7 @@ fn spawn_shards(
         thread::spawn(move || {
             shard_manager(
                 model_id,
+                adapter_id,
                 revision,
                 quantize,
                 dtype,
@@ -1081,7 +1098,13 @@ fn main() -> Result<(), LauncherError> {
     .expect("Error setting Ctrl-C handler");
 
     // Download and convert model weights
-    download_convert_model(&args, running.clone())?;
+    download_convert_model(args.model_id.to_string(), &args, running.clone())?;
+
+    // check if adapter_id is non-empty string
+    if !args.adapter_id.is_empty() {
+        download_convert_model(args.adapter_id.to_string(), &args, running.clone())?;
+    }
 
     if !running.load(Ordering::SeqCst) {
         // Launcher was asked to stop

View File

@@ -311,7 +311,7 @@ fn prepare_input(
             // truncate encoding and decode new inputs
             encoding.truncate(truncate, 0, TruncationDirection::Left);
             let inputs = tokenizer
-                .decode(Vec::from(encoding.get_ids()), false)
+                .decode(encoding.get_ids(), false)
                 .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
             (inputs, encoding.len())
         }

View File

@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
 
+from loguru import logger
 from opentelemetry import trace
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional
@@ -11,6 +12,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     LlamaConfig,
 )
 from text_generation_server.utils import (
+    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -29,15 +31,6 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        print("ASDFASDF FLASHLLAMA INIT")
-        print("Args:")
-        print(f"model_id: {model_id}")
-        print(f"adapter_id: {adapter_id}")
-        print(f"revision: {revision}")
-        print(f"quantize: {quantize}")
-        print(f"dtype: {dtype}")
-        print(f"trust_remote_code: {trust_remote_code}")
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -70,7 +63,23 @@ class FlashLlama(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+
+        # get adapter filenames if adapter_id passed in
+        merged_weight_filenames = None
+        if len(adapter_id) > 0:
+            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
+            merged_weight_filenames = create_merged_weight_files(
+                adapter_id, model_id, model_weight_filenames=filenames
+            )
+
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            merged_weight_filenames=merged_weight_filenames
+        )
+
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)

View File

@@ -1,3 +1,4 @@
+from text_generation_server.utils.adapter import create_merged_weight_files
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
@@ -20,6 +21,7 @@ from text_generation_server.utils.tokens import (
 )
 
 __all__ = [
+    "create_merged_weight_files",
     "convert_file",
     "convert_files",
     "initialize_torch_distributed",

View File

@@ -0,0 +1,127 @@
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import List, Dict, Set, Tuple
+
+import torch
+from loguru import logger
+from peft import LoraConfig
+from peft.utils import transpose
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+from text_generation_server.utils.hub import weight_files
+
+
+def compute_delta_weight(
+    lora_A: torch.Tensor,
+    lora_B: torch.Tensor,
+    fan_in_fan_out: bool,
+    alpha: float,
+    r: float
+) -> torch.Tensor:
+    """Computes the delta weight for a Linear layer given A and B LoRA matrices.
+
+    TODO: add logic for other module types beyond Linear layers.
+
+    Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806
+    """
+    scaling = alpha / r
+    delta_weight = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling
+    return delta_weight
+
+
+def merge_adapter_weights(
+    model_weights: Dict[str, torch.Tensor],
+    adapter_weights: Dict[str, torch.Tensor],
+    adapter_config: LoraConfig
+) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
+    """Merges the adapter weights into the model weights."""
+    module_mapping = defaultdict(dict)
+    processed_adapter_weight_names = set()
+
+    # map the original tensor names to their adapter counterparts
+    for weight_name in model_weights:
+        end_idx = weight_name.rfind(".weight")
+        key = weight_name[:end_idx]
+        for adapter_weight_name in adapter_weights:
+            if key in adapter_weight_name:
+                # example value: 'base_model.model.model.layers.10.self_attn.v_proj.lora_B.weight'
+                # matrix_type gets the second to last element in the module name, i.e. 'lora_B'
+                matrix_type = adapter_weight_name.split(".")[-2]
+                module_mapping[weight_name][matrix_type] = adapter_weight_name
+                processed_adapter_weight_names.add(adapter_weight_name)
+
+    # merge adapter weights into model weights
+    merged_weights = {}
+    for weight_name, adapter_weight_names in tqdm(
+        module_mapping.items(), desc="Merging adapter weights", total=len(module_mapping)):
+        # TODO: support adapter types beyond LoRA
+        lora_A = adapter_weights[adapter_weight_names["lora_A"]]
+        lora_B = adapter_weights[adapter_weight_names["lora_B"]]
+        delta_weight = compute_delta_weight(
+            lora_A, lora_B, adapter_config.fan_in_fan_out, adapter_config.lora_alpha, adapter_config.r)
+        merged_weights[weight_name] = model_weights[weight_name] + delta_weight
+    return merged_weights, processed_adapter_weight_names
+
+
+def create_merged_weight_files(
+    adapter_id: str,
+    model_id: str,
+    model_weight_filenames: List[Path]
+) -> List[Path]:
+    """Creates merged weight files for the given adapter ID and filenames."""
+    adapter_filenames = weight_files(adapter_id, extension=".safetensors")
+
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    if adapter_config.base_model_name_or_path != model_id:
+        raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
+
+    # load adapter weights from all shards (should have relatively small memory footprint)
+    adapter_weights = {}
+    for filename in adapter_filenames:
+        adapter_weights.update(load_file(filename))
+    remaining_adapter_weight_names = set(adapter_weights.keys())
+
+    merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
+    # just grab the existing files if they already exist and return immediately
+    if os.path.exists(merged_weight_directory):
+        logger.info("Merged weight files already exist, skipping merge computation.")
+        return weight_files(merged_weight_directory)
+
+    os.makedirs(merged_weight_directory)
+
+    merged_weight_filenames = []
+    for filename in model_weight_filenames:
+        model_weights = load_file(filename)
+        merged_weights, processed_adapter_weight_names = merge_adapter_weights(
+            model_weights, adapter_weights, adapter_config)
+
+        merged_adapter_filename = Path(merged_weight_directory, os.path.basename(filename))
+        save_file(merged_weights, merged_adapter_filename)
+        logger.debug(f"Saved merged weights into {merged_adapter_filename}")
+        merged_weight_filenames.append(merged_adapter_filename)
+        remaining_adapter_weight_names = remaining_adapter_weight_names.difference(
+            processed_adapter_weight_names)
+
+    if len(remaining_adapter_weight_names) > 0:
+        logger.warning("WARNING: The following lora weights were not merged into the model weights:")
+        for lora_name in remaining_adapter_weight_names:
+            logger.warning("\t" + lora_name)
+
+    return merged_weight_filenames
+
+
+def main():
+    adapter_id = "arnavgrg/codealpaca-qlora"
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    model_id = adapter_config.base_model_name_or_path
+
+    model_weight_filenames = weight_files(model_id, extension=".safetensors")
+    merged_adapter_filenames = create_merged_weight_files(adapter_id, model_id, model_weight_filenames)
+    print(merged_adapter_filenames)
+
+
+if __name__ == '__main__':
+    main()
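
For a Linear layer with `fan_in_fan_out = False` (where peft's `transpose` is a no-op), the merge implemented above reduces to `W_merged = W + (lora_B @ lora_A) * (lora_alpha / r)`. A minimal standalone sketch with made-up shapes, just to illustrate the arithmetic:

import torch

# Illustrative dimensions only: hidden size 16, LoRA rank 4, lora_alpha 8.
hidden, r, alpha = 16, 4, 8

W = torch.randn(hidden, hidden)   # base Linear weight, shape (out_features, in_features)
lora_A = torch.randn(r, hidden)   # LoRA down-projection, shape (r, in_features)
lora_B = torch.randn(hidden, r)   # LoRA up-projection, shape (out_features, r)

# Same math as compute_delta_weight(lora_A, lora_B, fan_in_fan_out=False, alpha=alpha, r=r)
delta = (lora_B @ lora_A) * (alpha / r)
W_merged = W + delta

# The merged tensor keeps the base weight's shape, which is why it can be written
# back into ordinary safetensors shards and served without any LoRA modules.
assert W_merged.shape == W.shape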

View File

@@ -15,19 +15,37 @@ class Weights:
         dtype,
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
+        merged_weight_filenames: Optional[List] = None,
     ):
-        # idea: maybe we can pass in adapter filenames here and have these take
-        # precedence over the model filenames? If so, then self.routing would
-        # just handle the mapping of tensor names to filenames.
+        # routes to adapter files take precedence over routes to main model files
+        # to ensure that adapter weights are loaded instead of main model weights
         routing = {}
+        if merged_weight_filenames is not None:
+            for filename in merged_weight_filenames:
+                with safe_open(filename, framework="pytorch") as f:
+                    for k in f.keys():
+                        if k in routing:
+                            raise RuntimeError(
+                                f"Key {k} was found in multiple adapter files: {filename} and {routing[k]}"
+                            )
+                        routing[k] = filename
+
+        # set of keys that point to adapter files. Duplicates for these keys found
+        # in main model files will be overridden.
+        adapter_routes = set(routing.keys())
+
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
-                    if k in routing:
+                    if k in adapter_routes:
+                        logger.debug(f"Overriding main model weights with adapter weights for key: {k}")
+                    elif k in routing:
                         raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                            f"Key {k} was found in multiple non-adapter files: {filename} and {routing[k]}"
                         )
-                    routing[k] = filename
+                    else:
+                        routing[k] = filename
+
         if aliases is None:
             aliases = {}
         self.aliases = aliases
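
The two-pass routing above can be summarized independently of safetensors handles: merged adapter files claim tensor keys first, and main model files only fill in keys that no adapter file provides. A small sketch of that precedence rule (the `build_routing` helper and the file names below are illustrative, not part of the codebase):

from typing import Dict, List, Set

def build_routing(
    merged_files: Dict[str, List[str]],   # filename -> tensor keys in that merged/adapter file
    model_files: Dict[str, List[str]],    # filename -> tensor keys in that main model file
) -> Dict[str, str]:
    routing: Dict[str, str] = {}

    # Pass 1: adapter/merged files. Duplicate keys across merged files are an error.
    for filename, keys in merged_files.items():
        for k in keys:
            if k in routing:
                raise RuntimeError(f"Key {k} found in multiple adapter files")
            routing[k] = filename
    adapter_routes: Set[str] = set(routing.keys())

    # Pass 2: main model files. Keys already routed to an adapter file are skipped,
    # so the merged weights win; duplicates among main files are still an error.
    for filename, keys in model_files.items():
        for k in keys:
            if k in adapter_routes:
                continue  # adapter-merged tensor overrides the base tensor
            if k in routing:
                raise RuntimeError(f"Key {k} found in multiple non-adapter files")
            routing[k] = filename
    return routing

# Example: the q_proj weight resolves to the merged shard, everything else to the base shards.
routing = build_routing(
    {"merged-00001.safetensors": ["model.layers.0.self_attn.q_proj.weight"]},
    {"model-00001.safetensors": ["model.layers.0.self_attn.q_proj.weight",
                                 "model.layers.0.self_attn.k_proj.weight"]},
)
assert routing["model.layers.0.self_attn.q_proj.weight"] == "merged-00001.safetensors"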